In [1]:
import torch
import os
from abc import ABC
from pathlib import Path
from tqdm import tqdm
import inspect
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch.utils.data import DataLoader
import os
from pathlib import Path
import inspect
from tqdm import tqdm
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.optim
import torch.optim.lr_scheduler
import matplotlib.pyplot as plt
from typing import Any
import pandas as pd
from sklearn.model_selection import KFold
from model_code.utils import AverageMeter, FocalLossWithWeights, FocalLoss
from model_code.data_loader import data_loader
from model_code.custom_model import CustomRain

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
class Settings:
    """Static configuration shared by the cells of this notebook.

    All values are plain class attributes, so they can be read either from
    the class itself or from an instance.
    """

    # --- model -----------------------------------------------------------
    # Passed to the model constructor together with `pretrain`.
    number_of_classes = 15
    pretrain = False
    # Checkpoint file whose state dict is loaded into the model.
    model_path = "RainDrop_0.5374524676799775.pt"

    # --- data locations (not referenced by the code visible in this cell) --
    path = "test_images"
    test_path = "test_images"
    description_path = "test_series_descriptions.csv"
    batch_size = 1

    # --- output layout ---------------------------------------------------
    # Number of prediction rows produced for every study.
    N_LABELS = 25
    # Column names of the per-row results table.
    LABELS = ["normal_mild", "moderate", "severe", "pred", "target"]
    

settings = Settings()

# Rebuild the network and restore the trained weights.
# `map_location=device` lets a checkpoint that was saved from a GPU run be
# loaded on a CPU-only machine (the `device` fallback above allows that case).
model = CustomRain(settings.number_of_classes, settings.pretrain)
model.load_state_dict(torch.load(settings.model_path, map_location=device))
model.eval()
model.to(device)

# NOTE(review): machine-specific absolute paths; move them to a configurable
# data directory before sharing this notebook.
train_labels = r"F:\Projects\Kaggle\RSNA-2024-Lumbar-Spine-Degenerative-Classification\csv\train_series_descriptions_with_paths.csv"
train_path = r"F:\Projects\Kaggle\RSNA-2024-Lumbar-Spine-Degenerative-Classification\train.csv"
train_descriptions = r"F:\Projects\Kaggle\RSNA-2024-Lumbar-Spine-Degenerative-Classification\train_series_descriptions.csv"

# Two dataset objects over the same files; the second is built with mode='val'.
train_dataset = data_loader(train_path, train_labels, train_descriptions)
validation_dataset = data_loader(train_path, train_labels, train_descriptions, mode='val')

train_df = pd.read_csv(train_path)
primary_labels = train_df['study_id'].values

# Seeded K-fold split so the same partition is produced on every run.
skf = KFold(n_splits=5, shuffle=True, random_state=42)

# Mixed-precision helpers (bfloat16 autocast and a gradient scaler).
autocast = torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16)
scaler = torch.cuda.amp.GradScaler(enabled=True)

# Class-weighted cross entropy; targets equal to -100 are ignored.
weights = torch.tensor([1.0, 2.0, 4.0], dtype=torch.float32)
criterion = nn.CrossEntropyLoss(weight=weights.to(device), ignore_index=-100)

# Only the first fold is used in this notebook, so take it directly instead of
# entering a loop that breaks after one iteration.
# The split indexes positions 0..len(primary_labels)-1 and is then applied to
# the datasets — assumes one dataset item per row of train.csv; TODO confirm.
fold, (train_idx, val_idx) = next(enumerate(skf.split(range(len(primary_labels)))))
print(f"Fold: {fold + 1}")
best_valid_loss = np.inf
early_stopping_counter = 0
train_subset = torch.utils.data.Subset(train_dataset, train_idx)
val_subset = torch.utils.data.Subset(train_dataset, val_idx)
validation_subset = torch.utils.data.Subset(validation_dataset, val_idx)

def reorder_labels(labels, num_groups=5, sets_per_group=5, set_size=3):
    """Swap the two outer grouping levels of a flat, set-structured tensor.

    The input is read along its first dimension as ``num_groups`` consecutive
    groups, each holding ``sets_per_group`` sets of ``set_size`` entries.
    The result holds the same sets regrouped so that the set position comes
    first and the group second: the set at ``(group g, position s)`` moves
    from offset ``(g * sets_per_group + s) * set_size`` to offset
    ``(s * num_groups + g) * set_size``.

    With the defaults (5, 5, 3) this is exactly the mapping
    ``dest = (set_index % 5) * 15 + (set_index // 5) * 3`` for 25 sets of 3.

    Args:
        labels: Tensor whose first dimension has length
            ``num_groups * sets_per_group * set_size``. Any trailing
            dimensions are carried along unchanged.
        num_groups: Number of outer groups in the input layout.
        sets_per_group: Number of sets inside each input group.
        set_size: Number of consecutive entries that form one set.

    Returns:
        A new tensor with the same shape, dtype and device as ``labels``.

    Raises:
        ValueError: If the first dimension of ``labels`` does not match the
            declared layout. Previously such inputs silently produced
            uninitialised or overwritten entries.
    """
    expected_len = num_groups * sets_per_group * set_size
    if len(labels) != expected_len:
        raise ValueError(
            f"reorder_labels expected {expected_len} entries "
            f"({num_groups} x {sets_per_group} x {set_size}), got {len(labels)}"
        )

    trailing_shape = tuple(labels.shape[1:])
    grouped = labels.reshape(num_groups, sets_per_group, set_size, *trailing_shape)
    # Swapping the first two axes performs the whole regrouping at once.
    # The explicit contiguous clone guarantees a fresh tensor that never
    # aliases the input, as the original copy-into-empty implementation did.
    swapped = grouped.transpose(0, 1).clone(memory_format=torch.contiguous_format)
    return swapped.reshape(labels.shape)

# Condition names, in the order used when building `row_id` strings
# (outer loop of the row-name generation).
CONDITIONS = [
    "spinal_canal_stenosis",
    "left_neural_foraminal_narrowing",
    "right_neural_foraminal_narrowing",
    "left_subarticular_stenosis",
    "right_subarticular_stenosis",
]

# Level names, in the order used when building `row_id` strings
# (inner loop of the row-name generation).
LEVELS = ["l1_l2", "l2_l3", "l3_l4", "l4_l5", "l5_s1"]


# Run the model over the validation fold and collect, for every study, one row
# per (condition, level) pair holding the three class probabilities, the
# predicted class and the target.

# Number of logits per (condition, level) pair; matches the first three
# entries of settings.LABELS (normal_mild / moderate / severe).
N_SEVERITY = 3

row_names = []
per_study_preds = []

model.eval()
# `torch.no_grad() and autocast` evaluates to `autocast` alone, so gradient
# tracking was never actually disabled. Separating the two context managers
# with a comma enters both of them.
with torch.no_grad(), autocast:
    progress_bar = tqdm(validation_subset, desc='Validation', leave=False)
    # The axial images are unpacked because the dataset yields them, but only
    # the sagittal T2 images are fed to the model here.
    for (study_id, sagittal_T2_l1_l2, axial_l1_l2, sagittal_T2_l2_l3,
            axial_l2_l3, sagittal_T2_l3_l4, axial_l3_l4, sagittal_T2_l4_l5, axial_l4_l5, sagittal_T2_l5_s1,
            axial_l5_s1, reordered_labels) in progress_bar:

        sagittal_by_level = (
            sagittal_T2_l1_l2,
            sagittal_T2_l2_l3,
            sagittal_T2_l3_l4,
            sagittal_T2_l4_l5,
            sagittal_T2_l5_s1,
        )

        # One forward pass per level. The dataset is iterated directly (no
        # DataLoader), so a batch dimension is added by hand. `.to(device)`
        # keeps the inputs on the same device as the model instead of
        # hard-coding CUDA.
        level_outputs = [
            model(image.to(device).unsqueeze(0), torch.tensor(level_idx, device=device)).squeeze()
            for level_idx, image in enumerate(sagittal_by_level)
        ]

        # Concatenated outputs are level-major; reorder_labels regroups them so
        # that they follow the CONDITIONS-then-LEVELS order of the row names.
        output = reorder_labels(torch.cat(level_outputs, dim=0))

        # str() is kept deliberately: formatting through an f-string may
        # render a tensor-valued id differently.
        study_prefix = str(study_id)
        row_names.extend(
            study_prefix + '_' + cond + '_' + level
            for cond in CONDITIONS
            for level in LEVELS
        )

        # One row of N_SEVERITY logits per output row of this study.
        logits = output.detach().float().reshape(settings.N_LABELS, N_SEVERITY)

        pred_per_study = np.zeros((settings.N_LABELS, len(settings.LABELS)))
        pred_per_study[:, :N_SEVERITY] = logits.softmax(dim=1).cpu().numpy()
        pred_per_study[:, N_SEVERITY] = logits.argmax(dim=1).cpu().numpy()
        # NOTE(review): the predictions above were regrouped by reorder_labels,
        # while the targets are read from `reordered_labels` in the order the
        # dataset provides them — confirm that both use the same ordering.
        pred_per_study[:, N_SEVERITY + 1] = [
            reordered_labels[col].item() for col in range(settings.N_LABELS)
        ]
        per_study_preds.append(pred_per_study)

# Stack the per-study blocks into one (n_studies * N_LABELS, 5) array.
y_preds = np.concatenate(per_study_preds, axis=0)

# Build the table in one step: row_id first, then the value columns.
submissions = pd.DataFrame(y_preds, columns=settings.LABELS)
submissions.insert(0, 'row_id', row_names)
submissions.head()


Fold: 1


Validation:   4%|▍         | 16/395 [01:17<26:56,  4.26s/it]

In [2]:
# Persist the per-row predictions and targets for offline inspection;
# the DataFrame index is not written to the file.
submissions.to_csv('submission_check.csv', index=False)

In [3]:
# Count rows where the predicted class differs from the target, tallied
# separately for each row position modulo N_POSITIONS.
#
# NOTE(review): every study contributes 25 rows to `submissions`, so a period
# of 15 does not map each bucket to a single (condition, level) pair — the
# level is fixed per bucket (bucket % 5) but conditions are mixed. Confirm
# whether 25 was intended here.
N_POSITIONS = 15

is_mismatch = (submissions["pred"] != submissions["target"]).to_numpy()
row_position = submissions.index.to_numpy() % N_POSITIONS

# bincount over the positions of the mismatching rows gives one tally per
# bucket; minlength keeps empty trailing buckets in the result.
mismatch_counts = np.bincount(row_position[is_mismatch], minlength=N_POSITIONS)

print(*mismatch_counts)
    


108 163 195 273 248 82 167 191 263 228 90 174 201 263 255


109 173 215 304 256 76 175 193 271 243 93 175 215 290 258

In [4]:
# Scratch cell: displays the float value of the literal 1e9.
1e9

1000000000.0

In [5]:
labels = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
def _reorder_of_labels(labels):
        # Create a list to hold the reordered labels
        # reordered_labels = [labels[0], labels[5], labels[10], labels[15], labels[20],
        #                     labels[1], labels[6], labels[11], labels[16], labels[21],
        #                     labels[2], labels[7], labels[12], labels[17], labels[22],
        #                     labels[3], labels[8], labels[13], labels[18], labels[23],
        #                     labels[4], labels[9], labels[14], labels[19], labels[24]]

        reordered_labels = []
        # Loop over the remainders from 0 to 4
        for i in range(5):
            # Add labels to reordered_labels based on the remainder when index is divided by 5
            reordered_labels.extend([label for idx, label in enumerate(labels) if idx % 5 == i])
        return reordered_labels
_reorder_of_labels(labels)

[0,
 5,
 10,
 15,
 20,
 1,
 6,
 11,
 16,
 21,
 2,
 7,
 12,
 17,
 22,
 3,
 8,
 13,
 18,
 23,
 4,
 9,
 14,
 19,
 24]