<a href="https://colab.research.google.com/github/ikmtd/Classification_Demonstration/blob/main/Classification_Demonstration_code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Google Drive mount: makes the Drive contents (including the
# Classification_Demonstration folder) available under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
#import Modules
import torch
from torch import nn, optim
import torchvision
from torchvision import transforms, utils, datasets, models
from torch.utils.data import (Dataset, DataLoader, TensorDataset)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import csv

In [3]:
print(torch.__version__)  # PyTorch version in use; the recorded run used 1.13.0+cu116.

1.13.0+cu116


In [78]:
#　Move to the location where the Classification_Demonstration folder is saved.
cd '/content/drive/My Drive/'

/content/drive/My Drive


In [79]:
# Test dataset: one folder per cell line, each holding class subfolders.
data_dir = 'Classification_Demonstration/data/test_data/'
wt_dir = data_dir + 'wt/'        # wt -/- TK6
ctf18_dir = data_dir + 'ctf18/'  # ctf18 -/- TK6

# Every image is resized to this resolution before use.
size = (224, 224)

data_transforms = dict(
    # Network input: resize, convert to a tensor, then normalize per channel
    # (these values match the ImageNet mean/std commonly used with torchvision models).
    test=transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    # Visualization: resize only (no ToTensor), so samples can be passed to plt.imshow.
    display=transforms.Compose([transforms.Resize(size)]),
)

# Same folders loaded twice: normalized tensors for inference, resized images for display.
wt_dataset = datasets.ImageFolder(wt_dir, transform=data_transforms['test'])
ctf18_dataset = datasets.ImageFolder(ctf18_dir, transform=data_transforms['test'])
wt_dataset_display = datasets.ImageFolder(wt_dir, transform=data_transforms['display'])
ctf18_dataset_display = datasets.ImageFolder(ctf18_dir, transform=data_transforms['display'])

# shuffle is left at its default (False) so predictions come out in dataset
# order; the display cell pairs prediction i with dataset image i.
dataloaders = {
    name: torch.utils.data.DataLoader(dataset, batch_size=64)
    for name, dataset in (('wt', wt_dataset), ('ctf18', ctf18_dataset))
}

In [7]:
# Correspondence between class-folder labels and indices.
# Folder a -> typeA (0), b -> typeB (1), c -> typeC (2); the print shows the mapping used.
print(wt_dataset.class_to_idx)

{'a': 0, 'b': 1, 'c': 2}


In [8]:
# Counting predicted number of images in dataset

def pred_count(ypreds, class_num):
  """Count how many predictions equal each class index.

  Args:
    ypreds: Tensor of predicted class indices (expected 1-D; any device).
    class_num: Number of classes; indices 0 .. class_num - 1 are counted.

  Returns:
    List of length class_num whose i-th entry is the number of elements equal to i.
  """
  pred_list = ypreds.detach().cpu().numpy().tolist()
  return [pred_list.count(class_idx) for class_idx in range(class_num)]

In [67]:
# Define Test Network
def test_net(csv_name, net, test_loader, device="cpu", class_num=3,
             results_dir='Classification_Demonstration/results/'):
  """Evaluate `net` on `test_loader`, save per-image predictions, and summarize them.

  Args:
    csv_name: Base name (without extension) of the CSV written into `results_dir`.
      The file holds one predicted class index per row, in loader order.
    net: Classifier producing one score per class for each input of a batch.
    test_loader: Iterable of (inputs, labels) batches, e.g. a DataLoader.
    device: Device to which inputs and labels are moved.
    class_num: Number of classes counted in the returned percentages.
    results_dir: Directory for the CSV; created if it does not exist.

  Returns:
    [accuracy, pred_percent]:
      accuracy is the fraction of predictions equal to the labels.
      pred_percent is a tuple giving, for each class index in range(class_num),
      the share of counted predictions with that index, rounded to 2 decimals.

  Raises:
    ValueError: if `test_loader` yields no batches.
  """
  import os

  net.eval()
  label_batches = []
  pred_batches = []
  with torch.no_grad():
    for x, y in test_loader:
      x = x.to(device)
      y = y.to(device)
      y_net = net(x)
      # Index of the highest-probability class for each sample.
      _, y_pred = nn.functional.softmax(y_net, dim=1).max(1)
      label_batches.append(y)
      pred_batches.append(y_pred)
  if not label_batches:
    raise ValueError('test_loader yielded no batches')

  # ys : Labels for all test data.
  ys = torch.cat(label_batches)
  # ypreds : Prediction results for all test data.
  ypreds = torch.cat(pred_batches)
  acc = (ys == ypreds).float().sum() / len(ys)
  pred_sum = pred_count(ypreds, class_num)

  os.makedirs(results_dir, exist_ok=True)
  pd.DataFrame(ypreds.cpu().numpy()).to_csv(
      os.path.join(results_dir, str(csv_name) + '.csv'), index=False, header=False)

  data_num = sum(pred_sum)
  pred_percent = tuple(
      round(count / data_num, 2) if data_num else 0.0 for count in pred_sum)
  return [acc.item(), pred_percent]

In [10]:
# Index of the saved model file to load (SqueezeNet_<net_num>.pth / ResNet18_<net_num>.pth).
# Both ResNet and SqueezeNet can be selected from 1 to 30.
net_num = 1

In [None]:
# SqueezeNet 1.1 with its final 1x1 convolution replaced by a 3-class head.
# Use the first GPU when available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# weights=None: the strict load_state_dict below replaces every weight, so
# downloading ImageNet weights (deprecated pretrained=True) is unnecessary.
squeezenet = models.squeezenet1_1(weights=None)
squeezenet.classifier[1] = nn.Conv2d(512, 3, kernel_size=(1, 1), stride=(1, 1))

# map_location allows a checkpoint saved on GPU to load on a CPU-only runtime.
squeezenet.load_state_dict(torch.load(
    'Classification_Demonstration/models/SqueezeNet/SqueezeNet_' + str(net_num) + '.pth',
    map_location=device))
squeezenet.to(device)

# wt cells chromosomes prediction
# The first argument of test_net() is the csv file name.
sq_acc_wt, sq_pred_wt = test_net('wt_SqueezeNet', squeezenet, dataloaders['wt'], device=device)
print("wt")
print("Concordance Rates : ", str(sq_acc_wt))
print("Predicted percentage")
print('typeA', sq_pred_wt[0], 'typeB', sq_pred_wt[1], 'typeC', sq_pred_wt[2])
print()

# ctf-18 cells chromosomes prediction
sq_acc_ctf18, sq_pred_ctf18 = test_net('ctf18_SqueezeNet', squeezenet, dataloaders['ctf18'], device=device)
print("ctf18")
print("Concordance Rates : ", str(sq_acc_ctf18))
print("Predicted percentage")
print('typeA', sq_pred_ctf18[0], 'typeB', sq_pred_ctf18[1], 'typeC', sq_pred_ctf18[2])

In [None]:
# ResNet-18 with its fully connected layer replaced by a 3-class head.
# Use the first GPU when available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# weights=None: the strict load_state_dict below replaces every weight and BN
# buffer, so downloading ImageNet weights (deprecated pretrained=True) is unnecessary.
resnet = models.resnet18(weights=None)
fc_input_dim = resnet.fc.in_features
resnet.fc = nn.Linear(fc_input_dim, 3)

# Path is relative to the working directory set by the %cd cell (the original
# prefixed an undefined `path` variable, which raised NameError).
resnet.load_state_dict(torch.load(
    'Classification_Demonstration/models/ResNet-18/ResNet18_' + str(net_num) + '.pth',
    map_location=device))
resnet.to(device)

# wt cells chromosomes prediction
# The first argument of test_net() is the csv file name.
res_acc_wt, res_pred_wt = test_net('wt_ResNet-18', resnet, dataloaders['wt'], device=device)
print("wt")
print("Concordance Rates : ", str(res_acc_wt))
print("Predicted percentage")
print('typeA', res_pred_wt[0], 'typeB', res_pred_wt[1], 'typeC', res_pred_wt[2])
print()

# ctf-18 cells chromosomes prediction
res_acc_ctf18, res_pred_ctf18 = test_net('ctf18_ResNet-18', resnet, dataloaders['ctf18'], device=device)
print("ctf18")
print("Concordance Rates : ", str(res_acc_ctf18))
print("Predicted percentage")
print('typeA', res_pred_ctf18[0], 'typeB', res_pred_ctf18[1], 'typeC', res_pred_ctf18[2])

# The following cells display each test image together with its label and the network's prediction.

In [51]:
def idx_to_class(num):
  """Map a class index to its display name (0 -> typeA, 1 -> typeB, 2 -> typeC).

  Any other value yields None.
  """
  for idx, class_name in enumerate(('typeA', 'typeB', 'typeC')):
    if num == idx:
      return class_name
  return None

In [None]:
# Set this to the name passed as the first argument of test_net()
# (e.g. 'wt_SqueezeNet', 'ctf18_ResNet-18'). Omit the '.csv' extension:
# the reading cell appends it, and 'wt_SqueezeNet.csv' would open 'wt_SqueezeNet.csv.csv'.
csv_filename = 'wt_SqueezeNet'

In [None]:
# Read back the predictions test_net() saved under csv_filename and show each
# image with its ground-truth label and predicted class.
results_path = 'Classification_Demonstration/results/' + str(csv_filename)
# Accept csv_filename with or without the '.csv' extension (avoids 'name.csv.csv').
if not results_path.endswith('.csv'):
  results_path += '.csv'

pred_class = []
with open(results_path, newline='') as f:
  for row in csv.reader(f):
    if row:  # skip blank lines
      pred_class.append(int(row[0]))

# Predictions are stored in DataLoader order (no shuffling), so row i belongs to
# image i of the same cell line's dataset; choose it from the file-name prefix.
if str(csv_filename).startswith('ctf18'):
  display_dataset = ctf18_dataset_display
else:
  display_dataset = wt_dataset_display

# pred_class can be long: lower n_display to preview only the first images.
n_display = len(pred_class)
for i in range(min(n_display, len(display_dataset))):
  image, label = display_dataset[i]
  fig, ax = plt.subplots()
  ax.imshow(image)
  plt.show()
  plt.close(fig)  # free the figure; many images may be shown
  print('label : ', idx_to_class(label), 'predict : ', idx_to_class(pred_class[i]))