In [1]:
import pandas as pd
import numpy as np
import os
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data.dataset import random_split


In [16]:
test_imgs_dir = "ondemand/test_feature_maps/"
test_csv_fpath = "test.csv"

In [8]:
label_dict = {'HGSC':0, 'EC':1, 'CC':2, 'LGSC':3, 'MC':4}
revlabel_dict = {v:k for k,v in label_dict.items() }

In [9]:
class ClassificationModel(nn.Module):
    def __init__(self, num_classes):
        super(ClassificationModel, self).__init__()
        self.fc1 = nn.Linear(2048, 64) 
        self.fc2 = nn.Linear(64, 32) 
        self.fc3 = nn.Linear(32, num_classes)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.softmax(self.fc3(x), dim = 1)
        return x


In [10]:
model = ClassificationModel(num_classes = 5)
model.load_state_dict(torch.load('model_acc_66aug.pt'))
model.eval()

ClassificationModel(
  (fc1): Linear(in_features=2048, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=32, bias=True)
  (fc3): Linear(in_features=32, out_features=5, bias=True)
)

In [28]:
res_predicted = []
res_fnames = []

for imag_fname in os.listdir(test_imgs_dir):

    img_path = os.path.join(test_imgs_dir, str(imag_fname))
    
    image = torch.load(img_path)

    outputs = model(image.unsqueeze(0))
    _, predicted = torch.max(outputs.data, 1)


    res_predicted.append(revlabel_dict[predicted.item()])
    res_fnames.append(imag_fname.split('.')[0])
    

In [29]:
res_df = pd.DataFrame()
res_df['image_id'] = res_fnames
res_df['label'] = res_predicted

res_df.head()

Unnamed: 0,image_id,label
0,19,HGSC
1,48,EC
2,105,HGSC
3,58,HGSC
4,0,HGSC


In [30]:
test_df = pd.read_csv(test_csv_fpath)
test_df

Unnamed: 0,image_id,label
0,0,
1,1,
2,2,
3,3,
4,4,
...,...,...
103,103,
104,104,
105,105,
106,106,


In [31]:
res_df['image_id'] = res_df['image_id'].astype('int')
res_df = pd.merge(test_df.drop(columns = ['label'], axis = 1), res_df, on = 'image_id')

In [32]:
res_df

Unnamed: 0,image_id,label
0,0,HGSC
1,1,HGSC
2,2,HGSC
3,3,HGSC
4,4,CC
...,...,...
103,103,HGSC
104,104,HGSC
105,105,HGSC
106,106,HGSC


In [33]:
res_df.to_csv('test.csv', index = False)