In [3]:
import pandas as pd
import numpy as np
import os
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data.dataset import random_split


In [11]:
train_csv_path = "../ondemand/train.csv"

In [12]:
train_imgs_dir = "../ondemand/train_feature_maps_aug/"

In [13]:
label_dict = {'HGSC':0, 'EC':1, 'CC':2, 'LGSC':3, 'MC':4}

In [14]:
revlabel_dict = {v:k for k,v in label_dict.items() }

In [16]:
class OvarianCancerFtrDataset(Dataset):

    def __init__(self, csv_file, root_dir, label_dict, transform=None):

        self.df = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.label_dict = label_dict
        self.image_fnames = os.listdir(root_dir)

    def __len__(self):
        return len(self.image_fnames)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = os.path.join(self.root_dir, str(self.image_fnames[idx]))
        image = torch.load(img_path)
        label = self.df.loc[self.df['image_id'] == int(self.image_fnames[idx].split('.')[0].split('_')[0])]['label'].iloc[0]
        label = self.label_dict[label]

        if self.transform:
            image = self.transform(image)

        return image, label

In [17]:
dataset = OvarianCancerFtrDataset(train_csv_path, train_imgs_dir, label_dict)

In [20]:
train_size = int(0.8 * len(dataset))  
val_size = len(dataset) - train_size  
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)


In [63]:
class ClassificationModel(nn.Module):
    def __init__(self, num_classes):
        super(ClassificationModel, self).__init__()
        self.fc1 = nn.Linear(2048, 64) 
        self.fc2 = nn.Linear(64, 32) 
        self.fc3 = nn.Linear(32, num_classes)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.softmax(self.fc3(x), dim = 1)
        return x


In [93]:
model = ClassificationModel(num_classes = 5)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)  


In [94]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device = torch.device("cpu")

model = model.to(device)


In [95]:
num_epochs = 50

for epoch in range(num_epochs):
    
    model.train()
    total_train_loss = 0

    for batch in tqdm(train_loader):
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device) 

        outputs = model(inputs.float())
        loss = criterion(outputs, labels.long())  

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()

    avg_train_loss = total_train_loss / len(train_loader)

    model.eval() 
    total_val_loss = 0
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad(): 
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device) 

            outputs = model(inputs.float())
            loss = criterion(outputs, labels.long())

            total_val_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            total_predictions += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()

    avg_val_loss = total_val_loss / len(val_loader)
    val_accuracy = correct_predictions / total_predictions

    print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')


100%|███████████████████████████████████████████████████████████████| 860/860 [00:02<00:00, 424.94it/s]


Epoch [1/50], Training Loss: 1.4995, Validation Loss: 1.4770, Validation Accuracy: 0.4163


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 927.84it/s]


Epoch [2/50], Training Loss: 1.4719, Validation Loss: 1.4484, Validation Accuracy: 0.4500


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 944.67it/s]


Epoch [3/50], Training Loss: 1.4457, Validation Loss: 1.4386, Validation Accuracy: 0.4872


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 948.56it/s]


Epoch [4/50], Training Loss: 1.4286, Validation Loss: 1.4140, Validation Accuracy: 0.4872


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 951.15it/s]


Epoch [5/50], Training Loss: 1.4153, Validation Loss: 1.4017, Validation Accuracy: 0.4953


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 943.77it/s]


Epoch [6/50], Training Loss: 1.4017, Validation Loss: 1.3862, Validation Accuracy: 0.5372


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 947.41it/s]


Epoch [7/50], Training Loss: 1.3860, Validation Loss: 1.3694, Validation Accuracy: 0.5477


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 944.38it/s]


Epoch [8/50], Training Loss: 1.3744, Validation Loss: 1.3543, Validation Accuracy: 0.5674


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 941.41it/s]


Epoch [9/50], Training Loss: 1.3628, Validation Loss: 1.3458, Validation Accuracy: 0.5605


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 939.03it/s]


Epoch [10/50], Training Loss: 1.3520, Validation Loss: 1.3407, Validation Accuracy: 0.5930


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 934.87it/s]


Epoch [11/50], Training Loss: 1.3452, Validation Loss: 1.3278, Validation Accuracy: 0.5837


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 914.06it/s]


Epoch [12/50], Training Loss: 1.3364, Validation Loss: 1.3202, Validation Accuracy: 0.6000


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 912.97it/s]


Epoch [13/50], Training Loss: 1.3311, Validation Loss: 1.3304, Validation Accuracy: 0.5616


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 932.44it/s]


Epoch [14/50], Training Loss: 1.3256, Validation Loss: 1.3173, Validation Accuracy: 0.5942


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 929.58it/s]


Epoch [15/50], Training Loss: 1.3196, Validation Loss: 1.3384, Validation Accuracy: 0.5535


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 931.94it/s]


Epoch [16/50], Training Loss: 1.3146, Validation Loss: 1.3017, Validation Accuracy: 0.6093


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 934.33it/s]


Epoch [17/50], Training Loss: 1.3097, Validation Loss: 1.3455, Validation Accuracy: 0.5616


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 926.80it/s]


Epoch [18/50], Training Loss: 1.3044, Validation Loss: 1.2954, Validation Accuracy: 0.6291


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 915.52it/s]


Epoch [19/50], Training Loss: 1.3024, Validation Loss: 1.3056, Validation Accuracy: 0.5977


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 935.40it/s]


Epoch [20/50], Training Loss: 1.2985, Validation Loss: 1.2847, Validation Accuracy: 0.6267


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 934.65it/s]


Epoch [21/50], Training Loss: 1.2926, Validation Loss: 1.2935, Validation Accuracy: 0.6186


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 928.21it/s]


Epoch [22/50], Training Loss: 1.2883, Validation Loss: 1.2874, Validation Accuracy: 0.6128


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 933.52it/s]


Epoch [23/50], Training Loss: 1.2871, Validation Loss: 1.2836, Validation Accuracy: 0.6174


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 934.94it/s]


Epoch [24/50], Training Loss: 1.2844, Validation Loss: 1.2739, Validation Accuracy: 0.6430


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 930.04it/s]


Epoch [25/50], Training Loss: 1.2830, Validation Loss: 1.2718, Validation Accuracy: 0.6430


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 932.30it/s]


Epoch [26/50], Training Loss: 1.2785, Validation Loss: 1.2727, Validation Accuracy: 0.6442


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 928.31it/s]


Epoch [27/50], Training Loss: 1.2777, Validation Loss: 1.2740, Validation Accuracy: 0.6419


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 929.45it/s]


Epoch [28/50], Training Loss: 1.2756, Validation Loss: 1.2687, Validation Accuracy: 0.6407


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 926.40it/s]


Epoch [29/50], Training Loss: 1.2701, Validation Loss: 1.3037, Validation Accuracy: 0.5965


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 932.93it/s]


Epoch [30/50], Training Loss: 1.2692, Validation Loss: 1.2621, Validation Accuracy: 0.6442


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 930.35it/s]


Epoch [31/50], Training Loss: 1.2682, Validation Loss: 1.2672, Validation Accuracy: 0.6477


100%|███████████████████████████████████████████████████████████████| 860/860 [00:01<00:00, 854.51it/s]


Epoch [32/50], Training Loss: 1.2685, Validation Loss: 1.2629, Validation Accuracy: 0.6488


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 936.99it/s]


Epoch [33/50], Training Loss: 1.2639, Validation Loss: 1.2578, Validation Accuracy: 0.6500


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 938.31it/s]


Epoch [34/50], Training Loss: 1.2617, Validation Loss: 1.2545, Validation Accuracy: 0.6547


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 934.70it/s]


Epoch [35/50], Training Loss: 1.2591, Validation Loss: 1.3134, Validation Accuracy: 0.5802


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 931.80it/s]


Epoch [36/50], Training Loss: 1.2562, Validation Loss: 1.2610, Validation Accuracy: 0.6488


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 927.02it/s]


Epoch [37/50], Training Loss: 1.2543, Validation Loss: 1.2872, Validation Accuracy: 0.6163


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 920.29it/s]


Epoch [38/50], Training Loss: 1.2524, Validation Loss: 1.2598, Validation Accuracy: 0.6547


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 912.73it/s]


Epoch [39/50], Training Loss: 1.2500, Validation Loss: 1.2563, Validation Accuracy: 0.6593


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 930.78it/s]


Epoch [40/50], Training Loss: 1.2501, Validation Loss: 1.2567, Validation Accuracy: 0.6512


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 930.70it/s]


Epoch [41/50], Training Loss: 1.2475, Validation Loss: 1.2505, Validation Accuracy: 0.6581


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 930.47it/s]


Epoch [42/50], Training Loss: 1.2461, Validation Loss: 1.2453, Validation Accuracy: 0.6686


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 924.31it/s]


Epoch [43/50], Training Loss: 1.2460, Validation Loss: 1.2546, Validation Accuracy: 0.6465


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 929.70it/s]


Epoch [44/50], Training Loss: 1.2430, Validation Loss: 1.2528, Validation Accuracy: 0.6581


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 923.68it/s]


Epoch [45/50], Training Loss: 1.2406, Validation Loss: 1.2496, Validation Accuracy: 0.6605


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 931.87it/s]


Epoch [46/50], Training Loss: 1.2405, Validation Loss: 1.2416, Validation Accuracy: 0.6686


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 935.25it/s]


Epoch [47/50], Training Loss: 1.2395, Validation Loss: 1.2437, Validation Accuracy: 0.6663


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 934.20it/s]


Epoch [48/50], Training Loss: 1.2386, Validation Loss: 1.2452, Validation Accuracy: 0.6663


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 911.78it/s]


Epoch [49/50], Training Loss: 1.2369, Validation Loss: 1.2557, Validation Accuracy: 0.6500


100%|███████████████████████████████████████████████████████████████| 860/860 [00:00<00:00, 927.06it/s]


Epoch [50/50], Training Loss: 1.2330, Validation Loss: 1.2396, Validation Accuracy: 0.6733


In [28]:
torch.save(model.state_dict(), 'model_acc_67aug.pt')
