In [1]:
import os
import torch
import torchio as tio

# Set the base directory
base_path = '/home/jovyan/shared/data/PDGM/UCSF-PDGM-v5/UCSF-PDGM-v5'
os.chdir(base_path)

print(f"Working from: {os.getcwd()}")
print("Extracting subject IDs...")
print("=" * 50)

# Get all subject directories and extract IDs
subject_dirs = [d for d in os.listdir('.') if os.path.isdir(d) and 'UCSF-PDGM-' in d and '_nifti' in d]

# Extract clean subject IDs
subject_ids = []
for subject_dir in subject_dirs:
    # Remove '_nifti' suffix to get clean ID
    subject_id = subject_dir.replace('_nifti', '')
    subject_ids.append(subject_id)

# Sort the subject IDs
subject_ids.sort()

print(f"Found {len(subject_ids)} subjects:")
#print("=" * 30)
#for i, subject_id in enumerate(subject_ids, 1):
#    print(f"{i:2d}. {subject_id}")

#print(f"\nTotal subjects: {len(subject_ids)}")


Working from: /home/jovyan/shared/data/PDGM/UCSF-PDGM-v5/UCSF-PDGM-v5
Extracting subject IDs...
Found 501 subjects:


In [2]:
import torch
import torchio as tio

# Dictionary to store all the DTI FA images
dti_fa_images = {}

# Loop through each subject and load their DTI FA image
for subject_id in subject_ids:
    try:
        # Construct the file path
        folder_name = f"{subject_id}_nifti"
        file_name = f"{subject_id}_DTI_eddy_FA.nii.gz"
        file_path = os.path.join('.', folder_name, file_name)
        
        # Check if file exists
        if os.path.exists(file_path):
            # Load the DTI FA image
            dti_fa_image = tio.ScalarImage(file_path)
            
            # Store in dictionary with subject ID as key
            dti_fa_images[subject_id] = dti_fa_image
            
            #print(f"✓ Loaded DTI FA for {subject_id}")
        else:
            print(f"✗ File not found for {subject_id}: {file_path}")
            
    except Exception as e:
        print(f"✗ Error loading {subject_id}: {str(e)}")

print(f"\nSuccessfully loaded {len(dti_fa_images)} DTI FA images")
print(f"Subject IDs with loaded images: {list(dti_fa_images.keys())[:10]}...")



Successfully loaded 501 DTI FA images
Subject IDs with loaded images: ['UCSF-PDGM-0004', 'UCSF-PDGM-0005', 'UCSF-PDGM-0007', 'UCSF-PDGM-0008', 'UCSF-PDGM-0009', 'UCSF-PDGM-0010', 'UCSF-PDGM-0011', 'UCSF-PDGM-0012', 'UCSF-PDGM-0013', 'UCSF-PDGM-0014']...


In [3]:
import pandas as pd
import os

# Load the metadata
metadata_path = '/home/jovyan/shared/data/PDGM/UCSF-PDGM-metadata_v5.csv'
metadata_df = pd.read_csv(metadata_path)

print(f"Metadata shape: {metadata_df.shape}")
print(f"Columns: {list(metadata_df.columns)}")
print("\nFirst few rows:")
print(metadata_df.head())

# Check the subject ID column name
print(f"\nPossible subject ID columns: {[col for col in metadata_df.columns if 'subject' in col.lower() or 'id' in col.lower()]}")

target_fields = ['ID', 'Sex', 'Age at MRI', 'WHO CNS Grade']

metadata_df=metadata_df[target_fields]
# Fix metadata ID column by zero-padding to match the structure of the scans which is 4 digits 
metadata_df['ID'] = metadata_df['ID'].apply(
    lambda x: 'UCSF-PDGM-' + x.split('-')[-1].zfill(4)
)


print(f"Metadata shape: {metadata_df.shape}")
print(f"Columns: {list(metadata_df.columns)}")
print("\nFirst few rows:")
print(metadata_df.head())

Metadata shape: (501, 16)
Columns: ['ID', 'Sex', 'Age at MRI', 'WHO CNS Grade', 'Final pathologic diagnosis (WHO 2021)', 'MGMT status', 'MGMT index', '1p/19q', 'IDH', '1-dead 0-alive', 'OS', 'EOR', 'Biopsy prior to imaging', 'BraTS21 ID', 'BraTS21 Segmentation Cohort', 'BraTS21 MGMT Cohort']

First few rows:
              ID Sex  Age at MRI  WHO CNS Grade  \
0  UCSF-PDGM-004   M          66              4   
1  UCSF-PDGM-005   F          80              4   
2  UCSF-PDGM-007   M          70              4   
3  UCSF-PDGM-008   M          70              4   
4  UCSF-PDGM-009   F          68              4   

  Final pathologic diagnosis (WHO 2021)    MGMT status MGMT index   1p/19q  \
0            Glioblastoma, IDH-wildtype       negative          0  unknown   
1            Glioblastoma, IDH-wildtype  indeterminate    unknown  unknown   
2            Glioblastoma, IDH-wildtype  indeterminate    unknown  unknown   
3            Glioblastoma, IDH-wildtype       negative          0  unkn

In [4]:

subjects = []

for subject_id in dti_fa_images.keys():
    # Find the matching row in metadata
    row = metadata_df[metadata_df['ID'] == subject_id]

    if row.empty:
        print(f"Subject ID {subject_id} not found in metadata.")
        continue  # Skip if no match found

    # Extract metadata values
    sex = row['Sex'].values[0]
    age = row['Age at MRI'].values[0]
    grade = int(row['WHO CNS Grade'].values[0]) - 3  # Convert to int; Grade 2: -1, 3: 0, 4: 1
    if grade == -1:
        grade += 1 #Grade 2: 0, Grade 3: 0, Grade 4: 1

    # Create the subject dict
    subject_dict = {
        'dti_fa': dti_fa_images[subject_id],
        'subject_id': subject_id,
        'sex': sex,
        'age': age,
        'grade': grade
    }

    subject = tio.Subject(subject_dict)
    subjects.append(subject)

print(f"Created {len(subjects)} TorchIO subjects")

# Check one subject
print(f"\nExample subject: {subjects[0]}")
print(f"Keys: {subjects[0].keys()}")
print(f"DTI FA shape: {subjects[0]['dti_fa'].data.shape}")

Created 501 TorchIO subjects

Example subject: Subject(Keys: ('dti_fa', 'subject_id', 'sex', 'age', 'grade'); images: 1)
Keys: dict_keys(['dti_fa', 'subject_id', 'sex', 'age', 'grade'])
DTI FA shape: torch.Size([1, 240, 240, 155])


In [5]:
# Create SubjectsDataset

SubjectsDataset = tio.SubjectsDataset(subjects)

In [6]:
#Performing a manual check 
print("dti_fa:", SubjectsDataset[221]['dti_fa'])
print("subject_id:", SubjectsDataset[221]['subject_id'])
print("sex:", SubjectsDataset[221]['sex'])
print("age:", SubjectsDataset[221]['age'])
print("grade:", SubjectsDataset[20]['grade'])


dti_fa: ScalarImage(shape: (1, 240, 240, 155); spacing: (1.00, 1.00, 1.00); orientation: LPS+; dtype: torch.FloatTensor; memory: 34.1 MiB)
subject_id: UCSF-PDGM-0257
sex: M
age: 32
grade: 1


In [7]:
# This a dataframe of 5 Male/5 Female for Grade 2, 3, 4
# Train set subject IDs - extracted from UCSF-PDGM dataset
# Format: Full UCSF-PDGM- prefix with 4-digit numbers

selected_subject_ids = [
    "UCSF-PDGM-0399", "UCSF-PDGM-0389", "UCSF-PDGM-0137", "UCSF-PDGM-0093", "UCSF-PDGM-0026", 
    "UCSF-PDGM-0474", "UCSF-PDGM-0375", "UCSF-PDGM-0031", "UCSF-PDGM-0480", "UCSF-PDGM-0264",
    "UCSF-PDGM-0514", "UCSF-PDGM-0077", "UCSF-PDGM-0044", "UCSF-PDGM-0454", "UCSF-PDGM-0432", 
    "UCSF-PDGM-0069", "UCSF-PDGM-0023", "UCSF-PDGM-0205", "UCSF-PDGM-0234", "UCSF-PDGM-0390",
    "UCSF-PDGM-0338", "UCSF-PDGM-0095", "UCSF-PDGM-0427", "UCSF-PDGM-0396", "UCSF-PDGM-0128", 
    "UCSF-PDGM-0382", "UCSF-PDGM-0213", "UCSF-PDGM-0539", "UCSF-PDGM-0180", "UCSF-PDGM-0396",
    "UCSF-PDGM-0433", "UCSF-PDGM-0215", "UCSF-PDGM-0367", "UCSF-PDGM-0121", "UCSF-PDGM-0297", 
    "UCSF-PDGM-0461", "UCSF-PDGM-0298", "UCSF-PDGM-0341", "UCSF-PDGM-0018", "UCSF-PDGM-0236",
    "UCSF-PDGM-0178", "UCSF-PDGM-0168", "UCSF-PDGM-0273", "UCSF-PDGM-0413", "UCSF-PDGM-0455", 
    "UCSF-PDGM-0312", "UCSF-PDGM-0237", "UCSF-PDGM-0250", "UCSF-PDGM-0066", "UCSF-PDGM-0509",
    "UCSF-PDGM-0380", "UCSF-PDGM-0494", "UCSF-PDGM-0362", "UCSF-PDGM-0207", "UCSF-PDGM-0541", 
    "UCSF-PDGM-0240", "UCSF-PDGM-0104", "UCSF-PDGM-0127", "UCSF-PDGM-0491", "UCSF-PDGM-0210",
    "UCSF-PDGM-0188", "UCSF-PDGM-0426", "UCSF-PDGM-0123", "UCSF-PDGM-0379", "UCSF-PDGM-0453", 
    "UCSF-PDGM-0243", "UCSF-PDGM-0112", "UCSF-PDGM-0510", "UCSF-PDGM-0134", "UCSF-PDGM-0372",
    "UCSF-PDGM-0516", "UCSF-PDGM-0448", "UCSF-PDGM-0331", "UCSF-PDGM-0320", "UCSF-PDGM-0140", 
    "UCSF-PDGM-0103", "UCSF-PDGM-0008", "UCSF-PDGM-0045", "UCSF-PDGM-0495", "UCSF-PDGM-0371",
    "UCSF-PDGM-0456", "UCSF-PDGM-0156", "UCSF-PDGM-0013", "UCSF-PDGM-0321", "UCSF-PDGM-0158", 
    "UCSF-PDGM-0465", "UCSF-PDGM-0525", "UCSF-PDGM-0536", "UCSF-PDGM-0251", "UCSF-PDGM-0204",
    "UCSF-PDGM-0318", "UCSF-PDGM-0165", "UCSF-PDGM-0197", "UCSF-PDGM-0530", "UCSF-PDGM-0470", 
    "UCSF-PDGM-0116", "UCSF-PDGM-0447", "UCSF-PDGM-0420", "UCSF-PDGM-0022", "UCSF-PDGM-0194",
    "UCSF-PDGM-0187", "UCSF-PDGM-0346", "UCSF-PDGM-0203", "UCSF-PDGM-0521", "UCSF-PDGM-0126", 
    "UCSF-PDGM-0458", "UCSF-PDGM-0307", "UCSF-PDGM-0102", "UCSF-PDGM-0468", "UCSF-PDGM-0279",
    "UCSF-PDGM-0065", "UCSF-PDGM-0037", "UCSF-PDGM-0043", "UCSF-PDGM-0374", "UCSF-PDGM-0135", 
    "UCSF-PDGM-0336", "UCSF-PDGM-0340", "UCSF-PDGM-0029", "UCSF-PDGM-0238", "UCSF-PDGM-0378",
    "UCSF-PDGM-0487", "UCSF-PDGM-0287", "UCSF-PDGM-0170", "UCSF-PDGM-0508", "UCSF-PDGM-0532", 
    "UCSF-PDGM-0522", "UCSF-PDGM-0198", "UCSF-PDGM-0506", "UCSF-PDGM-0024", "UCSF-PDGM-0162",
    "UCSF-PDGM-0429", "UCSF-PDGM-0223", "UCSF-PDGM-0489", "UCSF-PDGM-0078", "UCSF-PDGM-0075", 
    "UCSF-PDGM-0173", "UCSF-PDGM-0157", "UCSF-PDGM-0233", "UCSF-PDGM-0387", "UCSF-PDGM-0467",
    "UCSF-PDGM-0436", "UCSF-PDGM-0186", "UCSF-PDGM-0119", "UCSF-PDGM-0099", "UCSF-PDGM-0323", 
    "UCSF-PDGM-0020", "UCSF-PDGM-0042", "UCSF-PDGM-0291", "UCSF-PDGM-0086", "UCSF-PDGM-0512",
    "UCSF-PDGM-0449", "UCSF-PDGM-0150", "UCSF-PDGM-0183", "UCSF-PDGM-0457", "UCSF-PDGM-0088", 
    "UCSF-PDGM-0159", "UCSF-PDGM-0272", "UCSF-PDGM-0423", "UCSF-PDGM-0393", "UCSF-PDGM-0476",
    "UCSF-PDGM-0245", "UCSF-PDGM-0460", "UCSF-PDGM-0247", "UCSF-PDGM-0499", "UCSF-PDGM-0057", 
    "UCSF-PDGM-0083", "UCSF-PDGM-0445", "UCSF-PDGM-0364", "UCSF-PDGM-0144", "UCSF-PDGM-0383",
    "UCSF-PDGM-0106", "UCSF-PDGM-0353", "UCSF-PDGM-0483", "UCSF-PDGM-0053", "UCSF-PDGM-0517", 
    "UCSF-PDGM-0524", "UCSF-PDGM-0410", "UCSF-PDGM-0477", "UCSF-PDGM-0231", "UCSF-PDGM-0059",
    "UCSF-PDGM-0089", "UCSF-PDGM-0070", "UCSF-PDGM-0519", "UCSF-PDGM-0155", "UCSF-PDGM-0409", 
    "UCSF-PDGM-0497", "UCSF-PDGM-0261", "UCSF-PDGM-0440", "UCSF-PDGM-0118", "UCSF-PDGM-0282",
    "UCSF-PDGM-0252", "UCSF-PDGM-0406", "UCSF-PDGM-0490", "UCSF-PDGM-0164", "UCSF-PDGM-0400", 
    "UCSF-PDGM-0357", "UCSF-PDGM-0209", "UCSF-PDGM-0300", "UCSF-PDGM-0091", "UCSF-PDGM-0344",
    "UCSF-PDGM-0347", "UCSF-PDGM-0475", "UCSF-PDGM-0469", "UCSF-PDGM-0246", "UCSF-PDGM-0309", 
    "UCSF-PDGM-0039", "UCSF-PDGM-0071", "UCSF-PDGM-0174", "UCSF-PDGM-0397", "UCSF-PDGM-0011",
    "UCSF-PDGM-0132", "UCSF-PDGM-0471", "UCSF-PDGM-0441", "UCSF-PDGM-0161", "UCSF-PDGM-0016", 
    "UCSF-PDGM-0105", "UCSF-PDGM-0479", "UCSF-PDGM-0425", "UCSF-PDGM-0193", "UCSF-PDGM-0412",
    "UCSF-PDGM-0196", "UCSF-PDGM-0094", "UCSF-PDGM-0424", "UCSF-PDGM-0444", "UCSF-PDGM-0358", 
    "UCSF-PDGM-0146", "UCSF-PDGM-0286", "UCSF-PDGM-0534", "UCSF-PDGM-0349", "UCSF-PDGM-0416",
    "UCSF-PDGM-0329", "UCSF-PDGM-0526", "UCSF-PDGM-0130", "UCSF-PDGM-0087", "UCSF-PDGM-0360", 
    "UCSF-PDGM-0265", "UCSF-PDGM-0021", "UCSF-PDGM-0004", "UCSF-PDGM-0419", "UCSF-PDGM-0384",
    "UCSF-PDGM-0411", "UCSF-PDGM-0498", "UCSF-PDGM-0253", "UCSF-PDGM-0431", "UCSF-PDGM-0527", 
    "UCSF-PDGM-0451", "UCSF-PDGM-0032", "UCSF-PDGM-0502", "UCSF-PDGM-0513", "UCSF-PDGM-0228",
    "UCSF-PDGM-0302", "UCSF-PDGM-0014", "UCSF-PDGM-0274", "UCSF-PDGM-0147", "UCSF-PDGM-0283", 
    "UCSF-PDGM-0027", "UCSF-PDGM-0227", "UCSF-PDGM-0281", "UCSF-PDGM-0270", "UCSF-PDGM-0501",
    "UCSF-PDGM-0266", "UCSF-PDGM-0096", "UCSF-PDGM-0143", "UCSF-PDGM-0064", "UCSF-PDGM-0131", 
    "UCSF-PDGM-0538", "UCSF-PDGM-0462", "UCSF-PDGM-0067", "UCSF-PDGM-0417", "UCSF-PDGM-0202",
    "UCSF-PDGM-0407", "UCSF-PDGM-0368", "UCSF-PDGM-0342", "UCSF-PDGM-0260", "UCSF-PDGM-0225", 
    "UCSF-PDGM-0313", "UCSF-PDGM-0079", "UCSF-PDGM-0330", "UCSF-PDGM-0343", "UCSF-PDGM-0232",
    "UCSF-PDGM-0074", "UCSF-PDGM-0268", "UCSF-PDGM-0025", "UCSF-PDGM-0369", "UCSF-PDGM-0418", 
    "UCSF-PDGM-0185", "UCSF-PDGM-0055", "UCSF-PDGM-0084", "UCSF-PDGM-0392", "UCSF-PDGM-0348",
    "UCSF-PDGM-0166", "UCSF-PDGM-0114", "UCSF-PDGM-0206", "UCSF-PDGM-0409", "UCSF-PDGM-0136", 
    "UCSF-PDGM-0311", "UCSF-PDGM-0500", "UCSF-PDGM-0335", "UCSF-PDGM-0345", "UCSF-PDGM-0405",
    "UCSF-PDGM-0355", "UCSF-PDGM-0142", "UCSF-PDGM-0276", "UCSF-PDGM-0485", "UCSF-PDGM-0256", 
    "UCSF-PDGM-0459", "UCSF-PDGM-0373", "UCSF-PDGM-0529", "UCSF-PDGM-0214", "UCSF-PDGM-0484",
    "UCSF-PDGM-0398", "UCSF-PDGM-0107", "UCSF-PDGM-0308", "UCSF-PDGM-0325", "UCSF-PDGM-0169", 
    "UCSF-PDGM-0277", "UCSF-PDGM-0446", "UCSF-PDGM-0450", "UCSF-PDGM-0350", "UCSF-PDGM-0167",
    "UCSF-PDGM-0030", "UCSF-PDGM-0048", "UCSF-PDGM-0518", "UCSF-PDGM-0139", "UCSF-PDGM-0101", 
    "UCSF-PDGM-0422", "UCSF-PDGM-0520", "UCSF-PDGM-0080", "UCSF-PDGM-0438", "UCSF-PDGM-0452",
    "UCSF-PDGM-0366", "UCSF-PDGM-0258", "UCSF-PDGM-0473", "UCSF-PDGM-0381", "UCSF-PDGM-0428", 
    "UCSF-PDGM-0535", "UCSF-PDGM-0242", "UCSF-PDGM-0063", "UCSF-PDGM-0303", "UCSF-PDGM-0082",
    "UCSF-PDGM-0129", "UCSF-PDGM-0352", "UCSF-PDGM-0334", "UCSF-PDGM-0505", "UCSF-PDGM-0153", 
    "UCSF-PDGM-0288", "UCSF-PDGM-0326", "UCSF-PDGM-0068", "UCSF-PDGM-0435", "UCSF-PDGM-0402",
    "UCSF-PDGM-0442", "UCSF-PDGM-0503", "UCSF-PDGM-0363", "UCSF-PDGM-0537", "UCSF-PDGM-0090", 
    "UCSF-PDGM-0365", "UCSF-PDGM-0176", "UCSF-PDGM-0466", "UCSF-PDGM-0141", "UCSF-PDGM-0005",
    "UCSF-PDGM-0290", "UCSF-PDGM-0305", "UCSF-PDGM-0280", "UCSF-PDGM-0012", "UCSF-PDGM-0439", 
    "UCSF-PDGM-0195", "UCSF-PDGM-0433", "UCSF-PDGM-0332", "UCSF-PDGM-0415", "UCSF-PDGM-0007",
    "UCSF-PDGM-0269", "UCSF-PDGM-0149", "UCSF-PDGM-0511", "UCSF-PDGM-0464", "UCSF-PDGM-0481", 
    "UCSF-PDGM-0111", "UCSF-PDGM-0295", "UCSF-PDGM-0010", "UCSF-PDGM-0377", "UCSF-PDGM-0058",
    "UCSF-PDGM-0316", "UCSF-PDGM-0437", "UCSF-PDGM-0036", "UCSF-PDGM-0404", "UCSF-PDGM-0172", 
    "UCSF-PDGM-0401", "UCSF-PDGM-0085", "UCSF-PDGM-0201", "UCSF-PDGM-0122", "UCSF-PDGM-0035",
    "UCSF-PDGM-0015", "UCSF-PDGM-0403", "UCSF-PDGM-0376", "UCSF-PDGM-0327", "UCSF-PDGM-0108", 
    "UCSF-PDGM-0394", "UCSF-PDGM-0414", "UCSF-PDGM-0163", "UCSF-PDGM-0391", "UCSF-PDGM-0235"
]
Subjects_train = [
    subj for subj in subjects if subj['subject_id'] in selected_subject_ids
]
SubjectsDataset_train = tio.SubjectsDataset(Subjects_train)



In [8]:
#Performing a manual check 
#for i, subject in enumerate(SubjectsDataset_train):
#    print(f"Subject {i} grade: {subject['grade']}")


In [9]:
from torch.utils.data import DataLoader

def collate_subjects(batch):
    images = []
    labels = []
    for s in batch:
        img = s['dti_fa'].data  # shape: (1, H, W, D)
        img = img.permute(0, 3, 1, 2)  # shape: (1, D, H, W)
        images.append(img)
        labels.append(s['grade'])  # fix here
    images = torch.stack(images)  # (B, 1, D, H, W)
    labels = torch.tensor(labels)  # (B,)
    return images, labels

# DataLoader using the custom collate function
train_loader = DataLoader(
    SubjectsDataset_train, 
    batch_size=3, 
    shuffle=True,
    collate_fn=collate_subjects
)

In [10]:
for images, labels in train_loader:
    print(labels.dtype)  # This will print torch.int64
    break  # Just print for the first batch and stop

torch.int64


In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os

os.chdir('/home/jovyan/git-gliograde/Deep_Learning')
from resnet import generate_model

model = generate_model(
    model_depth=10,
    n_classes=2,          # We are using a 2-class solution now 
    n_input_channels=1    # set to 1 for grayscale or 3 for RGB
)

#print(model)
os.chdir('/home/jovyan/shared/data/PDGM/UCSF-PDGM-v5/UCSF-PDGM-v5')

In [12]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
import torch.nn as nn
import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=0.0001) 
#Grade 0 weight: 397 / (2 × 79) = 397 / 158 = 2.531,
#Grade 1 weight: 397 / (2 × 318) = 397 / 636 = 0.628
class_weights = torch.FloatTensor([2.531, 0.628])  
criterion = nn.CrossEntropyLoss()

In [13]:
torch.cuda.is_available()

True

In [14]:
# Training loop

for epoch in range(10):
    torch.cuda.empty_cache()
    model.train()
    running_loss = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {running_loss:.4f}")

Epoch 1, Loss: 66.2596
Epoch 2, Loss: 61.8307
Epoch 3, Loss: 64.0106
Epoch 4, Loss: 63.5565
Epoch 5, Loss: 61.1900
Epoch 6, Loss: 61.3254
Epoch 7, Loss: 63.9747
Epoch 8, Loss: 57.7605
Epoch 9, Loss: 59.7792
Epoch 10, Loss: 55.1815


In [15]:
# Simple final evaluation
print("\n" + "="*30)
print("FINAL EVALUATION")
print("="*30)

model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        predictions = torch.argmax(outputs, dim=1)
        
        all_preds.extend(predictions.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Calculate and display results
correct = sum([pred == label for pred, label in zip(all_preds, all_labels)])
total = len(all_labels)
accuracy = correct / total

print(f"Training Accuracy: {correct}/{total} = {accuracy:.3f} ({accuracy*100:.1f}%)")


FINAL EVALUATION
Training Accuracy: 187/397 = 0.471 (47.1%)


In [17]:
print("Predictions for all subjects:")
print("=" * 50)
print("Sample | True Grade | Predicted Grade")
print("-" * 50)

all_predictions = []
all_labels = []
sample_count = 0

with torch.no_grad():
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        predictions = torch.argmax(outputs, dim=1)
        
        for i in range(len(labels)):
            sample_count += 1
            true_grade = labels[i].item()
            pred_grade = predictions[i].item()
            
            print(f"{sample_count:6d} | {true_grade:10d} | {pred_grade:15d}")
            
            all_predictions.append(pred_grade)
            all_labels.append(true_grade)

print("=" * 50)

# Summary
unique_predictions = set(all_predictions)
unique_labels = set(all_labels)

print(f"\nSUMMARY:")
print(f"Total samples: {len(all_labels)}")
print(f"True labels used: {sorted(unique_labels)}")
print(f"Predicted labels: {sorted(unique_predictions)}")

# Count predictions
from collections import Counter
pred_counts = Counter(all_predictions)
label_counts = Counter(all_labels)

print(f"\nTrue label distribution:")
for grade in sorted(label_counts.keys()):
    print(f"  Grade {grade}: {label_counts[grade]} samples")

print(f"\nPredicted label distribution:")
for grade in sorted(pred_counts.keys()):
    print(f"  Grade {grade}: {pred_counts[grade]} samples")

# Check if predicting only one class
if len(unique_predictions) == 1:
    print(f"\n🚨 WARNING: Model is predicting ONLY Grade {list(unique_predictions)[0]} for ALL samples!")
    print("This means the model hasn't learned to distinguish between classes.")
else:
    print(f"\n✓ Model is predicting {len(unique_predictions)} different classes")

# Accuracy
correct = sum([p == l for p, l in zip(all_predictions, all_labels)])
accuracy = correct / len(all_labels)
print(f"\nAccuracy: {correct}/{len(all_labels)} = {accuracy:.3f} ({accuracy*100:.1f}%)")


Predictions for all subjects:
Sample | True Grade | Predicted Grade
--------------------------------------------------
     1 |          1 |               0
     2 |          1 |               1
     3 |          1 |               0
     4 |          1 |               0
     5 |          1 |               1
     6 |          1 |               1
     7 |          1 |               1
     8 |          1 |               1
     9 |          1 |               0
    10 |          0 |               0
    11 |          1 |               0
    12 |          1 |               1
    13 |          0 |               1
    14 |          1 |               1
    15 |          1 |               0
    16 |          1 |               0
    17 |          1 |               0
    18 |          1 |               1
    19 |          1 |               0
    20 |          1 |               0
    21 |          1 |               1
    22 |          1 |               1
    23 |          1 |               1
    24 

In [19]:
# Simple confusion matrix code

import torch
import numpy as np
from sklearn.metrics import confusion_matrix

# Get predictions from your model
model.eval()
all_predictions = []
all_labels = []

with torch.no_grad():
    for images, labels in train_loader:
        images = images.to(device)
        outputs = model(images)
        predictions = torch.argmax(outputs, dim=1)
        
        all_predictions.extend(predictions.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Create confusion matrix
cm = confusion_matrix(all_labels, all_predictions)

# Print results
print("Confusion Matrix:")
print("         Predicted")
print("         0    1")
print(f"Actual 0 {cm[0,0]:3d}  {cm[0,1]:3d}")
print(f"       1 {cm[1,0]:3d}  {cm[1,1]:3d}")

# Calculate accuracy
total = len(all_labels)
correct = cm[0,0] + cm[1,1]
accuracy = correct / total

print(f"\nAccuracy: {correct}/{total} = {accuracy:.3f} ({accuracy*100:.1f}%)")

Confusion Matrix:
         Predicted
         0    1
Actual 0  70    9
       1 201  117

Accuracy: 187/397 = 0.471 (47.1%)
