In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
import torchvision.models as models
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os
import cv2
import random
from collections import Counter
import pydicom
from sklearn.metrics import precision_score,accuracy_score,f1_score,recall_score

In [2]:
label = pd.read_csv(r'/kaggle/input/iaaa-mri-challenge/train.csv')
print(Counter(label['prediction']))

Counter({0: 2741, 1: 391})


In [3]:
def position_finder(pos):
  pos = list(np.round(pos))
  if pos == [1.0 , 0.0 ,0.0 ,0.0, 0.0, -1.0]:
    return 'Coronal'
  elif pos == [0.0 , 1.0 ,0.0 ,0.0, 0.0, -1.0]:
    return 'Sagittal'
  elif pos == [1.0 , 0.0 ,0.0 ,0.0, 1.0, 0.0]:
    return 'Axial'
  else:
    print(pos)
    return 'None'

def Type_Axial_filter(path,df):
    Type = []
    Position = []
    for i in df['SeriesInstanceUID']:
        SerisPath = os.path.join(path,i)
        SamplePath = os.path.join(SerisPath,os.listdir(SerisPath)[0])
        Sample = pydicom.dcmread(SamplePath)
        Position.append(position_finder(Sample.ImageOrientationPatient))
        Type.append(Sample.SeriesDescription)
        
    df['Orientation']= Position
    df['Modality']= Type
    
    return df

label_new = Type_Axial_filter(r'/kaggle/input/iaaa-mri-challenge/data',label)

[1.0, 0.0, 0.0, -0.0, 1.0, -1.0]
[1.0, 0.0, 0.0, -0.0, 1.0, -1.0]
[1.0, 0.0, 0.0, -0.0, 1.0, -1.0]
[1.0, 0.0, 0.0, 0.0, 1.0, -1.0]
[1.0, 0.0, 0.0, 0.0, 1.0, -1.0]
[1.0, 0.0, 0.0, 0.0, 1.0, -1.0]
[1.0, 0.0, 0.0, 0.0, 1.0, -1.0]
[1.0, 0.0, 0.0, 0.0, 1.0, -1.0]
[1.0, 0.0, 0.0, 0.0, 1.0, -1.0]
[1.0, 0.0, 0.0, 0.0, 1.0, -1.0]
[1.0, 0.0, 0.0, 0.0, 1.0, -1.0]
[1.0, 0.0, 0.0, 0.0, 1.0, -1.0]
[1.0, 0.0, 0.0, -0.0, 1.0, -1.0]
[1.0, 0.0, 0.0, -0.0, 1.0, -1.0]
[1.0, 0.0, 0.0, -0.0, 1.0, -1.0]
[1.0, 0.0, -0.0, -0.0, 1.0, -1.0]
[1.0, 0.0, -0.0, -0.0, 1.0, -1.0]
[1.0, 0.0, -0.0, -0.0, 1.0, -1.0]


In [4]:
print('T1: ',Counter(label_new[label_new['Modality']=='T1W_SE']['prediction']))
print('T2: ',Counter(label_new[label_new['Modality']=='T2W_TSE']['prediction']))
print('Flair: ',Counter(label_new[label_new['Modality']=='T2W_FLAIR']['prediction']))
print('no AXIAL: ',Counter(label_new[label_new['Orientation']!='Axial']['prediction']))

T1:  Counter({0: 918, 1: 126})
T2:  Counter({0: 911, 1: 133})
Flair:  Counter({0: 912, 1: 132})
no AXIAL:  Counter({0: 20, 1: 6})


In [5]:
print(Counter(label_new['Modality']))
print(Counter(label_new['Orientation']))

Counter({'T2W_TSE': 1044, 'T1W_SE': 1044, 'T2W_FLAIR': 1044})
Counter({'Axial': 3106, 'None': 18, 'Sagittal': 8})


In [6]:
class MRIDataset(Dataset):
    def __init__(self, root_img, label, transform=None, Interpolation = False, pad=False,
                 Filter_Axial=True,Filter_type =False,balanced = False,n=0):
        
        self.root_img = root_img
        self.transform = transform
        self.Interpolation = Interpolation
        self.pad = pad
        self.label = label
        if Filter_Axial:
            self.label = label[label['Orientation']=='Axial']
        if Filter_type:
            self.label = self.label[self.label['Modality']== Filter_type]
            self.num_anbnormal = self.label[self.label['prediction']==1].shape[0]
        else:
            self.num_anbnormal = self.label[self.label['prediction']==1].shape[0]
            
        if balanced:
            self.label = pd.concat([self.label[self.label['prediction']==1],
                                    self.label[self.label['prediction']==0].iloc[self.num_anbnormal*n:(n+1)*self.num_anbnormal]])
        
        self.image_paths = []
        self.labels = []

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        label = self.label.iloc[idx]['prediction']
        patient_path = os.path.join(self.root_img, self.label.iloc[idx]['SeriesInstanceUID'])
        
        if self.Interpolation:
            images = self.interpolate_slices(self.read_dicom(patient_path),20)
            images = np.array(images,dtype=float)
            images = images / images.max()
        elif self.pad:
            images = np.stack(self.read_dicom(patient_path),axis = 0,dtype=float)
            images = images / images.max()
            images = self.padding(images)
        else:
            images = np.stack(self.read_dicom(patient_path),axis = 0,dtype=float)
            images = images / images.max()
        
        if self.transform:
            images = self.transform(images)
            
            
        images = torch.from_numpy(images)
        
        modality = self.label.iloc[idx]['Modality']
        if modality == 'T1W_SE':
            domain = 0 
        elif modality == 'T2W_TSE':
            domain = 1
        elif modality == 'T2W_FLAIR':
            domain = 2
        
        return images, label, domain
    
    def read_dicom(self,path):
        dicom_files = [pydicom.dcmread(os.path.join(path, f)) for f in os.listdir(path) if f.endswith('.dcm')]
        dicom_files_with_location = []
        
        for dicom_file in dicom_files:
            try:
                location = self.get_slice_location(dicom_file)
                dicom_files_with_location.append((location, dicom_file))
            except ValueError as e:
                print(e)
                
        dicom_files_with_location.sort(key=lambda x: x[0])
        sorted_dicom_files = [cv2.resize(file.pixel_array,(288,288)) for _, file in dicom_files_with_location]
    
        return sorted_dicom_files
                
    def get_slice_location(self,dicom_data):
        """
        Extracts the slice location from a DICOM file.
        """
        # Try to get the SliceLocation attribute first, fallback to ImagePositionPatient if unavailable
        try:
            return dicom_data.SliceLocation
        except AttributeError:
            try:
                return dicom_data.ImagePositionPatient[2]  # Assuming axial slices, z-coordinate
            except AttributeError:
                raise ValueError(f"Cannot determine slice location for file: {dicom_file}")
    
    def interpolate_slices(self,dicom_series, target_num_slices):
        image_data = np.stack(dicom_series, axis=0)
        original_num_slices = image_data.shape[0]
        zoom_factors = [target_num_slices / original_num_slices] + [1] * (image_data.ndim - 1)
        interpolated_data = zoom(image_data, zoom_factors, order=1)  # Linear interpolation
        return interpolated_data
    
    def padding(self,dicoms_file):
        m,n,k = dicoms_file.shape
        pad = np.zeros((20-m,n,k))
        return np.concatenate((dicoms_file,pad),axis=0)
        
imgs_path = r'/kaggle/input/iaaa-mri-challenge/data'
transform = transforms.Compose([
                    transforms.ToTensor(),
                    ])        
# MRTData = MRIDataset(imgs_path,label,transform=False,Interpolation =False,
#                      pad = True,
# #                      Filter_type='T2W_FLAIR',
#                      balanced=True) 
# print('number of Data: ',len(MRTData))
# MRTData[0][1]

In [7]:
epochs = 20        # training epochs
batch_size = 8  
learning_rate = 1e-4


In [8]:
def kfold(dataset):
    Data = DataLoader(dataset, batch_size=batch_size,shuffle=False,num_workers=4,pin_memory=True)
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
    trainloader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True,num_workers=4,pin_memory=True,drop_last=True)
    testloader = DataLoader(test_dataset, batch_size=batch_size,shuffle=False,num_workers=4,pin_memory=True)
    return trainloader,testloader
    

In [9]:
folds = [MRIDataset(imgs_path,label,transform=False,Interpolation=False,pad = True,balanced=True,n=i) for i in range(7)]
fold_loder = [kfold(j) for j in folds]
fold_loder

[(<torch.utils.data.dataloader.DataLoader at 0x799327844100>,
  <torch.utils.data.dataloader.DataLoader at 0x799327844490>),
 (<torch.utils.data.dataloader.DataLoader at 0x7993278441c0>,
  <torch.utils.data.dataloader.DataLoader at 0x799327845600>),
 (<torch.utils.data.dataloader.DataLoader at 0x799327845060>,
  <torch.utils.data.dataloader.DataLoader at 0x799327845d50>),
 (<torch.utils.data.dataloader.DataLoader at 0x799327845f60>,
  <torch.utils.data.dataloader.DataLoader at 0x7993278449a0>),
 (<torch.utils.data.dataloader.DataLoader at 0x7993278462c0>,
  <torch.utils.data.dataloader.DataLoader at 0x799327845f00>),
 (<torch.utils.data.dataloader.DataLoader at 0x799327844ee0>,
  <torch.utils.data.dataloader.DataLoader at 0x799327846050>),
 (<torch.utils.data.dataloader.DataLoader at 0x799327844700>,
  <torch.utils.data.dataloader.DataLoader at 0x7993278455a0>)]

In [10]:
final_check_data = MRIDataset(imgs_path,label,transform=False,Interpolation =False,
                     pad = True,
#                      Filter_type='T2W_FLAIR',
                     balanced=False) 
finalcheckloader = DataLoader(final_check_data, batch_size=batch_size,shuffle=False,num_workers=4,pin_memory=True)

train_size = int(0.8 * 770)
test_size = 770 - train_size

In [None]:
train_size = int(0.8 * len(MRTData))
test_size = len(MRTData) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(MRTData, [train_size, test_size])
trainloader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True,num_workers=4,pin_memory=True,drop_last=True)
testloader = DataLoader(test_dataset, batch_size=batch_size,shuffle=False,num_workers=4,pin_memory=True)
print('number of Train Data: ',train_size)
print('number of Train Data: ',test_size)

In [11]:
from torch.autograd import Function
from torchvision.models.video import r3d_18, R3D_18_Weights


class ReverseLayerF(Function):

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha

        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha

        return output, None


class CNNModel(nn.Module):

    def __init__(self):
        super(CNNModel, self).__init__()
        
        weights = R3D_18_Weights.DEFAULT
        model = r3d_18(weights=weights)
        self.feature = nn.Sequential(
                        nn.Conv3d(1, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2),
                               padding=(1, 3, 3), bias=False),
                        *(list(model.children())[1:-1]))
        
        

        

        self.class_classifier = nn.Sequential()
        self.class_classifier.add_module('c_fc1', nn.Linear(512, 100))
        self.class_classifier.add_module('c_bn1', nn.BatchNorm1d(100))
        self.class_classifier.add_module('c_relu1', nn.ReLU(True))
        self.class_classifier.add_module('c_drop1', nn.Dropout())
        self.class_classifier.add_module('c_fc2', nn.Linear(100, 100))
        self.class_classifier.add_module('c_bn2', nn.BatchNorm1d(100))
        self.class_classifier.add_module('c_relu2', nn.ReLU(True))
        self.class_classifier.add_module('c_fc3', nn.Linear(100, 2))
        self.class_classifier.add_module('c_softmax', nn.LogSoftmax(dim=1))

        self.domain_classifier = nn.Sequential()
        self.domain_classifier.add_module('d_fc1', nn.Linear(512, 100))
        self.domain_classifier.add_module('d_bn1', nn.BatchNorm1d(100))
        self.domain_classifier.add_module('d_relu1', nn.ReLU(True))
        self.domain_classifier.add_module('d_fc2', nn.Linear(100, 3))
        self.domain_classifier.add_module('d_softmax', nn.LogSoftmax(dim=1))

    def forward(self, input_data, alpha):
        input_data = input_data
        
        feature = feature.view(-1,512)
        reverse_feature = ReverseLayerF.apply(feature, alpha)
        class_output = self.class_classifier(feature)
        domain_output = self.domain_classifier(reverse_feature)

        return class_output, domain_output
    
    def forward(self, input_data, alpha):
        input_data = input_data.reshape(input_data.data.shape[0],1,20,288, 288)
        feature = self.feature(input_data)
        feature = feature.view(-1, 512)
        reverse_feature = ReverseLayerF.apply(feature, alpha)
        class_output = self.class_classifier(feature)
        domain_output = self.domain_classifier(reverse_feature)

        return class_output, domain_output
    
    def feature1(self, input_data, alpha):
        feature = self.feature(input_data)
        feature = feature.view(-1, 512)
        return feature
    
net = CNNModel()
net.cuda()

Downloading: "https://download.pytorch.org/models/r3d_18-b3b3357e.pth" to /root/.cache/torch/hub/checkpoints/r3d_18-b3b3357e.pth
100%|██████████| 127M/127M [00:00<00:00, 167MB/s]  


CNNModel(
  (feature): Sequential(
    (0): Conv3d(1, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False)
    (1): Sequential(
      (0): BasicBlock(
        (conv1): Sequential(
          (0): Conv3DSimple(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (conv2): Sequential(
          (0): Conv3DSimple(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (1): BasicBlock(
        (conv1): Sequential(
          (0): Conv3DSimple(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          

In [12]:
# setup optimizer

optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

loss_class = torch.nn.NLLLoss()
loss_domain = torch.nn.NLLLoss()

net = net.cuda()
loss_class = loss_class.cuda()
loss_domain = loss_domain.cuda()

for p in net.parameters():
    p.requires_grad = True

# training
acc_s_list = []
acc_t_list = []
alpha_list =[]
best_accu_t = 0.0
jj = 1
f1_last = 0

for j in fold_loder:
    print(f'fold: {jj}')
    jj+=1
    trainloder,testloader = j
    for epoch in range(epochs):
        net.train()
        for i,data in enumerate(trainloder):

            p = float(i + epoch * train_size) / epochs / train_size
            alpha = 2. / (1. + np.exp(-10 * p)) - 1
            alpha_list.append(alpha)

            # training model using source data
            img, label ,domain= data[0].cuda(),data[1].cuda(),data[2].cuda()

            optimizer.zero_grad()

            class_output, domain_output = net(input_data=img.float(), alpha=alpha)
            err_s_label = loss_class(class_output, label)
            err_s_domain = loss_domain(domain_output, domain)

            err = err_s_domain + err_s_label
            err.backward()
            optimizer.step() 
        
            if i%10 == 0:
                print(f'label loss: {err_s_label} || domain loss {err_s_domain}')
            
        all_y = []
        all_y_pred = []
        test_loss = 0
        net.eval()
        with torch.no_grad():
            for X, y,_ in testloader:
                # distribute data to device
                X, y = X.cuda(), y.cuda().view(-1, )

                output,_ = net(X.float(),alpha=alpha)

                loss = loss_class(output, y)
                test_loss += loss.item()                 # sum up batch loss
                y_pred = output.max(1, keepdim=True)[1]
                all_y.extend(y)
                all_y_pred.extend(y_pred)
            
            all_y = torch.stack(all_y, dim=0)
            all_y_pred = torch.stack(all_y_pred, dim=0)
            f1 = f1_score(all_y.cpu().data.squeeze().numpy(),
                      all_y_pred.cpu().data.squeeze().numpy())
            recall = recall_score(all_y.cpu().data.squeeze().numpy(),
                      all_y_pred.cpu().data.squeeze().numpy())
            precision = precision_score(all_y.cpu().data.squeeze().numpy(),
                      all_y_pred.cpu().data.squeeze().numpy())
        
            print("=====================")
            print(f'EPOCH:{epoch}')
            print(f'Test Loss: {100*test_loss/test_size}')
            print(f'Test f1: {100*f1}')
            print(f'Test Precision: {100*precision}')
            print(f'Test Recall: {100*recall}')
            
        
            if f1>f1_last:
                torch.save(net.state_dict(),r'/kaggle/working/model.pt')
                print('Model Saved')
            
                f1_last =f1
            print("=====================")

fold: 1
label loss: 0.9082282781600952 || domain loss 1.3218164443969727
label loss: 0.5251924395561218 || domain loss 1.0606375932693481
label loss: 0.4704505503177643 || domain loss 0.837471067905426
label loss: 0.5157139897346497 || domain loss 0.9292780160903931
label loss: 0.5832529664039612 || domain loss 0.9847438335418701
label loss: 0.657667338848114 || domain loss 0.9194598197937012
label loss: 0.8527093529701233 || domain loss 0.7984603643417358
label loss: 0.787510097026825 || domain loss 0.7502750754356384
EPOCH:0
Test Loss: 9.155665512208815
Test f1: 60.0
Test Precision: 50.467289719626166
Test Recall: 73.97260273972603
Model Saved
label loss: 0.521632194519043 || domain loss 0.680031418800354
label loss: 0.5162954330444336 || domain loss 1.3752691745758057
label loss: 0.6377092599868774 || domain loss 3.142918109893799
label loss: 0.9797672033309937 || domain loss 3.5193371772766113
label loss: 0.7362856864929199 || domain loss 3.1339266300201416
label loss: 0.6420175433

In [14]:
all_y = []
all_y_pred = []
test_loss = 0
net.eval()
with torch.no_grad():
    for X, y,_ in finalcheckloader:
            # distribute data to device
        X, y = X.cuda(), y.cuda().view(-1, )

        output,_ = net(X.float(),alpha=alpha)

        loss = loss_class(output, y)
        test_loss += loss.item()                 # sum up batch loss
        y_pred = output.max(1, keepdim=True)[1]
        all_y.extend(y)
        all_y_pred.extend(y_pred)
            
    all_y = torch.stack(all_y, dim=0)
    all_y_pred = torch.stack(all_y_pred, dim=0)
    f1 = f1_score(all_y.cpu().data.squeeze().numpy(),
                all_y_pred.cpu().data.squeeze().numpy())
    recall = recall_score(all_y.cpu().data.squeeze().numpy(),
                all_y_pred.cpu().data.squeeze().numpy())
    precision = precision_score(all_y.cpu().data.squeeze().numpy(),
                      all_y_pred.cpu().data.squeeze().numpy())
        
    print("=====================")
    print(f'Loss f1: {100*test_loss/3106}')
    print(f'Test f1: {100*f1}')
    print(f'Test Precision: {100*precision}')
    print(f'Test Recall: {100*recall}')
    print("=====================")

Loss f1: 0.5194718716797285
Test f1: 95.40372670807453
Test Precision: 91.42857142857143
Test Recall: 99.74025974025975


In [15]:
torch.save(net.state_dict(),r'/kaggle/working/model_fianl.pt')

In [16]:
torch.save(net,r'/kaggle/working/model_fianl2.pt')

In [17]:
alpha

0.9998592621337588