In [1]:
import os
import torch
import torchio as tio

# Set the base directory
base_path = '/home/jovyan/shared/data/PDGM/UCSF-PDGM-v5/UCSF-PDGM-v5'
os.chdir(base_path)

print(f"Working from: {os.getcwd()}")
print("Extracting subject IDs...")
print("=" * 50)

# Get all subject directories and extract IDs
subject_dirs = [d for d in os.listdir('.') if os.path.isdir(d) and 'UCSF-PDGM-' in d and '_nifti' in d]

# Extract clean subject IDs
subject_ids = []
for subject_dir in subject_dirs:
    # Remove '_nifti' suffix to get clean ID
    subject_id = subject_dir.replace('_nifti', '')
    subject_ids.append(subject_id)

# Sort the subject IDs
subject_ids.sort()

print(f"Found {len(subject_ids)} subjects:")
print("=" * 30)
for i, subject_id in enumerate(subject_ids, 1):
    print(f"{i:2d}. {subject_id}")

print(f"\nTotal subjects: {len(subject_ids)}")


Working from: /home/jovyan/shared/data/PDGM/UCSF-PDGM-v5/UCSF-PDGM-v5
Extracting subject IDs...
Found 501 subjects:
 1. UCSF-PDGM-0004
 2. UCSF-PDGM-0005
 3. UCSF-PDGM-0007
 4. UCSF-PDGM-0008
 5. UCSF-PDGM-0009
 6. UCSF-PDGM-0010
 7. UCSF-PDGM-0011
 8. UCSF-PDGM-0012
 9. UCSF-PDGM-0013
10. UCSF-PDGM-0014
11. UCSF-PDGM-0015
12. UCSF-PDGM-0016
13. UCSF-PDGM-0017
14. UCSF-PDGM-0018
15. UCSF-PDGM-0019
16. UCSF-PDGM-0020
17. UCSF-PDGM-0021
18. UCSF-PDGM-0022
19. UCSF-PDGM-0023
20. UCSF-PDGM-0024
21. UCSF-PDGM-0025
22. UCSF-PDGM-0026
23. UCSF-PDGM-0027
24. UCSF-PDGM-0029
25. UCSF-PDGM-0030
26. UCSF-PDGM-0031
27. UCSF-PDGM-0032
28. UCSF-PDGM-0033
29. UCSF-PDGM-0034
30. UCSF-PDGM-0035
31. UCSF-PDGM-0036
32. UCSF-PDGM-0037
33. UCSF-PDGM-0038
34. UCSF-PDGM-0039
35. UCSF-PDGM-0040
36. UCSF-PDGM-0041
37. UCSF-PDGM-0042
38. UCSF-PDGM-0043
39. UCSF-PDGM-0044
40. UCSF-PDGM-0045
41. UCSF-PDGM-0046
42. UCSF-PDGM-0047
43. UCSF-PDGM-0048
44. UCSF-PDGM-0049
45. UCSF-PDGM-0050
46. UCSF-PDGM-0053
47. UCSF-P

In [2]:
import torch
import torchio as tio

# Dictionary to store all the DTI FA images
dti_fa_images = {}

# Loop through each subject and load their DTI FA image
for subject_id in subject_ids:
    try:
        # Construct the file path
        folder_name = f"{subject_id}_nifti"
        file_name = f"{subject_id}_DTI_eddy_FA.nii.gz"
        file_path = os.path.join('.', folder_name, file_name)
        
        # Check if file exists
        if os.path.exists(file_path):
            # Load the DTI FA image
            dti_fa_image = tio.ScalarImage(file_path)
            
            # Store in dictionary with subject ID as key
            dti_fa_images[subject_id] = dti_fa_image
            
            print(f"✓ Loaded DTI FA for {subject_id}")
        else:
            print(f"✗ File not found for {subject_id}: {file_path}")
            
    except Exception as e:
        print(f"✗ Error loading {subject_id}: {str(e)}")

print(f"\nSuccessfully loaded {len(dti_fa_images)} DTI FA images")
print(f"Subject IDs with loaded images: {list(dti_fa_images.keys())[:10]}...")


✓ Loaded DTI FA for UCSF-PDGM-0004
✓ Loaded DTI FA for UCSF-PDGM-0005
✓ Loaded DTI FA for UCSF-PDGM-0007
✓ Loaded DTI FA for UCSF-PDGM-0008
✓ Loaded DTI FA for UCSF-PDGM-0009
✓ Loaded DTI FA for UCSF-PDGM-0010
✓ Loaded DTI FA for UCSF-PDGM-0011
✓ Loaded DTI FA for UCSF-PDGM-0012
✓ Loaded DTI FA for UCSF-PDGM-0013
✓ Loaded DTI FA for UCSF-PDGM-0014
✓ Loaded DTI FA for UCSF-PDGM-0015
✓ Loaded DTI FA for UCSF-PDGM-0016
✓ Loaded DTI FA for UCSF-PDGM-0017
✓ Loaded DTI FA for UCSF-PDGM-0018
✓ Loaded DTI FA for UCSF-PDGM-0019
✓ Loaded DTI FA for UCSF-PDGM-0020
✓ Loaded DTI FA for UCSF-PDGM-0021
✓ Loaded DTI FA for UCSF-PDGM-0022
✓ Loaded DTI FA for UCSF-PDGM-0023
✓ Loaded DTI FA for UCSF-PDGM-0024
✓ Loaded DTI FA for UCSF-PDGM-0025
✓ Loaded DTI FA for UCSF-PDGM-0026
✓ Loaded DTI FA for UCSF-PDGM-0027
✓ Loaded DTI FA for UCSF-PDGM-0029
✓ Loaded DTI FA for UCSF-PDGM-0030
✓ Loaded DTI FA for UCSF-PDGM-0031
✓ Loaded DTI FA for UCSF-PDGM-0032
✓ Loaded DTI FA for UCSF-PDGM-0033
✓ Loaded DTI FA for 

In [3]:
import pandas as pd
import os

# Load the metadata
metadata_path = '/home/jovyan/shared/data/PDGM/UCSF-PDGM-metadata_v5.csv'
metadata_df = pd.read_csv(metadata_path)

print(f"Metadata shape: {metadata_df.shape}")
print(f"Columns: {list(metadata_df.columns)}")
print("\nFirst few rows:")
print(metadata_df.head())

# Check the subject ID column name
print(f"\nPossible subject ID columns: {[col for col in metadata_df.columns if 'subject' in col.lower() or 'id' in col.lower()]}")

target_fields = ['ID', 'Sex', 'Age at MRI', 'WHO CNS Grade']

metadata_df=metadata_df[target_fields]
# Fix metadata ID column by zero-padding to match the structure of the scans which is 4 digits 
metadata_df['ID'] = metadata_df['ID'].apply(
    lambda x: 'UCSF-PDGM-' + x.split('-')[-1].zfill(4)
)


print(f"Metadata shape: {metadata_df.shape}")
print(f"Columns: {list(metadata_df.columns)}")
print("\nFirst few rows:")
print(metadata_df.head())

Metadata shape: (501, 16)
Columns: ['ID', 'Sex', 'Age at MRI', 'WHO CNS Grade', 'Final pathologic diagnosis (WHO 2021)', 'MGMT status', 'MGMT index', '1p/19q', 'IDH', '1-dead 0-alive', 'OS', 'EOR', 'Biopsy prior to imaging', 'BraTS21 ID', 'BraTS21 Segmentation Cohort', 'BraTS21 MGMT Cohort']

First few rows:
              ID Sex  Age at MRI  WHO CNS Grade  \
0  UCSF-PDGM-004   M          66              4   
1  UCSF-PDGM-005   F          80              4   
2  UCSF-PDGM-007   M          70              4   
3  UCSF-PDGM-008   M          70              4   
4  UCSF-PDGM-009   F          68              4   

  Final pathologic diagnosis (WHO 2021)    MGMT status MGMT index   1p/19q  \
0            Glioblastoma, IDH-wildtype       negative          0  unknown   
1            Glioblastoma, IDH-wildtype  indeterminate    unknown  unknown   
2            Glioblastoma, IDH-wildtype  indeterminate    unknown  unknown   
3            Glioblastoma, IDH-wildtype       negative          0  unkn

In [4]:

subjects = []

for subject_id in dti_fa_images.keys():
    # Find the matching row in metadata
    row = metadata_df[metadata_df['ID'] == subject_id]

    if row.empty:
        print(f"Subject ID {subject_id} not found in metadata.")
        continue  # Skip if no match found

    # Extract metadata values
    sex = row['Sex'].values[0]
    age = row['Age at MRI'].values[0]
    grade = int(row['WHO CNS Grade'].values[0]) - 2  # Convert to int and subtract 2

    # Create the subject dict
    subject_dict = {
        'dti_fa': dti_fa_images[subject_id],
        'subject_id': subject_id,
        'sex': sex,
        'age': age,
        'grade': grade
    }

    subject = tio.Subject(subject_dict)
    subjects.append(subject)

print(f"Created {len(subjects)} TorchIO subjects")

# Check one subject
print(f"\nExample subject: {subjects[0]}")
print(f"Keys: {subjects[0].keys()}")
print(f"DTI FA shape: {subjects[0]['dti_fa'].data.shape}")

Created 501 TorchIO subjects

Example subject: Subject(Keys: ('dti_fa', 'subject_id', 'sex', 'age', 'grade'); images: 1)
Keys: dict_keys(['dti_fa', 'subject_id', 'sex', 'age', 'grade'])
DTI FA shape: torch.Size([1, 240, 240, 155])


In [5]:
# Create SubjectsDataset

SubjectsDataset = tio.SubjectsDataset(subjects)

In [6]:
#Performing a manual check 
print("dti_fa:", SubjectsDataset[221]['dti_fa'])
print("subject_id:", SubjectsDataset[221]['subject_id'])
print("sex:", SubjectsDataset[221]['sex'])
print("age:", SubjectsDataset[221]['age'])
print("grade:", SubjectsDataset[221]['grade'])


dti_fa: ScalarImage(shape: (1, 240, 240, 155); spacing: (1.00, 1.00, 1.00); orientation: LPS+; dtype: torch.FloatTensor; memory: 34.1 MiB)
subject_id: UCSF-PDGM-0257
sex: M
age: 32
grade: 0


In [7]:
# This a dataframe of 5 Male/5 Female for Grade 2, 3, 4
selected_subject_ids = [
    'UCSF-PDGM-0004', 'UCSF-PDGM-0005', 'UCSF-PDGM-0241', 'UCSF-PDGM-0231', 'UCSF-PDGM-0327',
    'UCSF-PDGM-0305', 'UCSF-PDGM-0007', 'UCSF-PDGM-0009', 'UCSF-PDGM-0243', 'UCSF-PDGM-0254',
    'UCSF-PDGM-0439', 'UCSF-PDGM-0351', 'UCSF-PDGM-0008', 'UCSF-PDGM-0011', 'UCSF-PDGM-0249',
    'UCSF-PDGM-0268', 'UCSF-PDGM-0444', 'UCSF-PDGM-0438', 'UCSF-PDGM-0010', 'UCSF-PDGM-0012',
    'UCSF-PDGM-0251', 'UCSF-PDGM-0274', 'UCSF-PDGM-0448', 'UCSF-PDGM-0442', 'UCSF-PDGM-0015',
    'UCSF-PDGM-0013', 'UCSF-PDGM-0262', 'UCSF-PDGM-0277', 'UCSF-PDGM-0456', 'UCSF-PDGM-0443',
    'UCSF-PDGM-0017', 'UCSF-PDGM-0014', 'UCSF-PDGM-0263', 'UCSF-PDGM-0282', 'UCSF-PDGM-0465',
    'UCSF-PDGM-0446', 'UCSF-PDGM-0018', 'UCSF-PDGM-0016', 'UCSF-PDGM-0272', 'UCSF-PDGM-0285',
    'UCSF-PDGM-0478', 'UCSF-PDGM-0475', 'UCSF-PDGM-0020', 'UCSF-PDGM-0019', 'UCSF-PDGM-0499',
    'UCSF-PDGM-0326', 'UCSF-PDGM-0483', 'UCSF-PDGM-0476', 'UCSF-PDGM-0021', 'UCSF-PDGM-0022',
    'UCSF-PDGM-0500', 'UCSF-PDGM-0357', 'UCSF-PDGM-0485', 'UCSF-PDGM-0477', 'UCSF-PDGM-0024',
    'UCSF-PDGM-0023', 'UCSF-PDGM-0501', 'UCSF-PDGM-0367', 'UCSF-PDGM-0490', 'UCSF-PDGM-0540'
]
Subjects_small = [
    subj for subj in subjects if subj['subject_id'] in selected_subject_ids
]
SubjectsDataset_small = tio.SubjectsDataset(Subjects_small)


In [8]:
#Performing a manual check 
print("dti_fa:", SubjectsDataset_small[20]['dti_fa'])
print("subject_id:", SubjectsDataset_small[20]['subject_id'])
print("sex:", SubjectsDataset_small[20]['sex'])
print("age:", SubjectsDataset_small[20]['age'])
print("grade:", SubjectsDataset_small[20]['grade'])

dti_fa: ScalarImage(shape: (1, 240, 240, 155); spacing: (1.00, 1.00, 1.00); orientation: LPS+; dtype: torch.FloatTensor; memory: 34.1 MiB)
subject_id: UCSF-PDGM-0231
sex: F
age: 36
grade: 1


In [9]:
def collate_subjects(batch):
    images = torch.stack([s['dti_fa'].data for s in batch])  # (B, 1, D, H, W)
    labels = torch.tensor([s['grade'] for s in batch])        # (B,)
    return images, labels

In [10]:
from torch.utils.data import DataLoader

train_loader = DataLoader(
    SubjectsDataset_small, 
    batch_size=8, 
    shuffle=True,
    collate_fn=collate_subjects
)

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Simple3DCNN(nn.Module):
    def __init__(self, num_classes=3):
        super(Simple3DCNN, self).__init__()
        self.conv1 = nn.Conv3d(1, 8, 3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool3d(2)
        
        self.conv2 = nn.Conv3d(8, 16, 3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool3d(2)
        
        self.gap = nn.AdaptiveAvgPool3d((1,1,1))
        self.fc1 = nn.Linear(16, 64)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.gap(x)
        x = torch.flatten(x, 1)
        x = self.relu3(self.fc1(x))
        return self.fc2(x)

In [17]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Simple3DCNN(num_classes=3).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [18]:
for epoch in range(10):
    model.train()
    running_loss = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {running_loss:.4f}")

Epoch 1, Loss: 8.8214
Epoch 2, Loss: 8.8283
Epoch 3, Loss: 8.7867
Epoch 4, Loss: 8.8045
Epoch 5, Loss: 8.8150
Epoch 6, Loss: 8.7950
Epoch 7, Loss: 8.8069
Epoch 8, Loss: 8.7917
Epoch 9, Loss: 8.7930
Epoch 10, Loss: 8.7920


In [19]:
# Simple final evaluation
print("\n" + "="*30)
print("FINAL EVALUATION")
print("="*30)

model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        predictions = torch.argmax(outputs, dim=1)
        
        all_preds.extend(predictions.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Calculate and display results
correct = sum([pred == label for pred, label in zip(all_preds, all_labels)])
total = len(all_labels)
accuracy = correct / total

print(f"Training Accuracy: {correct}/{total} = {accuracy:.3f} ({accuracy*100:.1f}%)")
print("✓ Pipeline working!" if accuracy > 0.5 else "⚠ Check pipeline")


FINAL EVALUATION
Training Accuracy: 20/60 = 0.333 (33.3%)
⚠ Check pipeline


In [20]:
# Simple prediction analysis - just print all predictions vs true labels

model.eval()
all_predictions = []
all_labels = []
sample_count = 0

print("Predictions for all subjects:")
print("=" * 50)
print("Sample | True Grade | Predicted Grade")
print("-" * 35)

with torch.no_grad():
    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        predictions = torch.argmax(outputs, dim=1)
        
        # Print each sample in this batch
        for i in range(len(labels)):
            sample_count += 1
            true_grade = labels[i].item()
            pred_grade = predictions[i].item()
            
            print(f"{sample_count:6d} | {true_grade:10d} | {pred_grade:15d}")
            
            all_predictions.append(pred_grade)
            all_labels.append(true_grade)

print("=" * 50)

# Summary
unique_predictions = set(all_predictions)
unique_labels = set(all_labels)

print(f"\nSUMMARY:")
print(f"Total samples: {len(all_labels)}")
print(f"True labels used: {sorted(unique_labels)}")
print(f"Predicted labels: {sorted(unique_predictions)}")

# Count predictions
from collections import Counter
pred_counts = Counter(all_predictions)
label_counts = Counter(all_labels)

print(f"\nTrue label distribution:")
for grade in sorted(label_counts.keys()):
    print(f"  Grade {grade}: {label_counts[grade]} samples")

print(f"\nPredicted label distribution:")
for grade in sorted(pred_counts.keys()):
    print(f"  Grade {grade}: {pred_counts[grade]} samples")

# Check if predicting only one class
if len(unique_predictions) == 1:
    print(f"\n🚨 WARNING: Model is predicting ONLY Grade {list(unique_predictions)[0]} for ALL samples!")
    print("This means the model hasn't learned to distinguish between classes.")
else:
    print(f"\n✓ Model is predicting {len(unique_predictions)} different classes")

# Accuracy
correct = sum([p == l for p, l in zip(all_predictions, all_labels)])
accuracy = correct / len(all_labels)
print(f"\nAccuracy: {correct}/{len(all_labels)} = {accuracy:.3f} ({accuracy*100:.1f}%)")

Predictions for all subjects:
Sample | True Grade | Predicted Grade
-----------------------------------
     1 |          2 |               0
     2 |          1 |               0
     3 |          0 |               0
     4 |          1 |               0
     5 |          2 |               0
     6 |          0 |               0
     7 |          2 |               0
     8 |          2 |               0
     9 |          0 |               0
    10 |          1 |               0
    11 |          1 |               0
    12 |          0 |               0
    13 |          2 |               0
    14 |          2 |               0
    15 |          2 |               0
    16 |          0 |               0
    17 |          0 |               0
    18 |          1 |               0
    19 |          1 |               0
    20 |          0 |               0
    21 |          1 |               0
    22 |          2 |               0
    23 |          1 |               0
    24 |          1 | 