In [11]:
import os
import time
import copy
import tqdm

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F

from torchvision.models import resnet50, densenet121, mobilenet_v2
from torchvision.datasets import ImageFolder 
from torch.utils.data import DataLoader
from torch.optim import AdamW

from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve, auc

# Pick the compute device: the first CUDA GPU when one is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [12]:
# One training epoch
def train(model, train_loader, optimizer):
    """Train `model` for a single pass over `train_loader`.

    Every batch is moved to the module-level `device`, the cross-entropy
    loss between the model output and the targets is back-propagated, and
    `optimizer` takes one step. Nothing is returned; `model` and
    `optimizer` are updated in place.

    Args:
        model: network to train; it is switched to training mode.
        train_loader: iterable yielding `(data, target)` batches.
        optimizer: optimizer that updates the parameters of `model`.
    """
    model.train()
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Clear gradients left over from the previous step before the new backward pass
        optimizer.zero_grad()
        batch_loss = F.cross_entropy(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()
        
# For model evaluation
def evaluate(model, test_loader):
    """Evaluate a two-class classifier on `test_loader`.

    The model is put in eval mode and run without gradient tracking on
    every batch, with data moved to the module-level `device`.

    Args:
        model: network to evaluate; its output must have one column per class.
        test_loader: iterable yielding `(data, target)` batches and exposing
            a `.dataset` whose length is the number of evaluated samples.

    Returns:
        tuple: `(test_loss, test_accuracy, cm, roc_auc)` where
            test_loss is the summed cross-entropy divided by the dataset size,
            test_accuracy is the percentage of correct argmax predictions,
            cm is the confusion matrix of targets vs. argmax predictions,
            roc_auc is the area under the ROC curve built from the softmax
            probability assigned to class index 1.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    all_predictions = []
    all_targets = []
    all_scores = []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)

            # Sum (not mean) per batch so the final division by the dataset
            # size gives a per-sample average even with a ragged last batch
            test_loss += F.cross_entropy(output, target, reduction='sum').item()

            # Hard class decision, kept 1-D so the collected lists are flat
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()

            # Continuous score for the positive class (index 1); the ROC curve
            # needs scores, not hard labels, to sweep over decision thresholds
            scores = F.softmax(output, dim=1)[:, 1]

            all_predictions.extend(pred.cpu().tolist())
            all_targets.extend(target.cpu().tolist())
            all_scores.extend(scores.cpu().tolist())

    n_samples = len(test_loader.dataset)
    test_loss /= n_samples
    test_accuracy = 100. * correct / n_samples

    cm = confusion_matrix(all_targets, all_predictions)

    # AUC from class-1 probabilities rather than from thresholded predictions
    fpr, tpr, _ = roc_curve(all_targets, all_scores)
    roc_auc = auc(fpr, tpr)

    return test_loss, test_accuracy, cm, roc_auc

In [14]:
# Directory that contains the MTF image data (expects a 'train' sub-folder)
data_dir = '../MTF_spl/'

# Training hyperparameters
learning_rate = 2e-3
epoch = 100
batch_size = 512

In [15]:
# Only a 'train' split is configured here. To use a validation split as well,
# divide the dataset beforehand and add a matching 'val' entry to data_transforms.
# NOTE(review): images are resized to 224x224 and then randomly cropped to 52x52,
# so the network receives 52x52 patches - confirm this is intended.
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomCrop(52),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}

# One ImageFolder per split, read from <data_dir>/<split>
image_datasets = {
    split: ImageFolder(root=os.path.join(data_dir, split), transform=data_transforms[split])
    for split in data_transforms
}

# Shuffled mini-batch loaders over the datasets
dataloaders = {
    split: DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    for split, dataset in image_datasets.items()
}

# Number of samples per split, used to average the running loss / accuracy
dataset_sizes = {split: len(dataset) for split, dataset in image_datasets.items()}

class_names = image_datasets['train'].classes

In [16]:
# Transfer-learning loop written on the assumption that only a train set exists
def train_transfer(model, criterion, optimizer, num_epochs):
    """Fine-tune `model` on the module-level `dataloaders['train']`.

    After every epoch the average loss and accuracy over the training set are
    printed. The weights of the epoch with the highest *training* accuracy
    are remembered and loaded back into the model before it is returned
    (there is no validation phase in this loop).

    Args:
        model: network to train; batches are moved to the module-level `device`.
        criterion: callable `criterion(outputs, labels)` returning the batch loss.
        optimizer: optimizer that updates the parameters of `model`.
        num_epochs: number of passes over the training loader.

    Returns:
        The same `model` object, carrying the best-training-accuracy weights.
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('-------------- epoch {} ----------------'.format(epoch+1))
        since = time.time()
        phase = 'train'
        model.train()

        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in dataloaders[phase]:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            # Force gradient tracking on even if the caller disabled it
            with torch.set_grad_enabled(True):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                loss.backward()
                optimizer.step()

            # Weight the batch loss by batch size so the epoch average is per sample
            running_loss += loss.item() * inputs.size(0)
            # Accumulate a plain Python int: keeps the counters off the device and
            # avoids calling tensor methods on the initial integer 0
            running_corrects += (preds == labels).sum().item()

        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects / dataset_sizes[phase]

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

        # Snapshot the weights whenever the training accuracy improves
        if epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())

        time_elapsed = time.time() - since
        print('Completed in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    print('Best train Acc: {:.4f}'.format(best_acc))

    model.load_state_dict(best_model_wts)

    return model

In [None]:
# Select the desired model and replace its last layer with a 2-output head
model = resnet50(pretrained=True)
# Read the head's input width from the existing layer instead of hard-coding it,
# so the same lines keep working if a different ResNet variant is selected
model.fc = nn.Linear(model.fc.in_features, 2)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
# Use the learning rate defined with the other hyperparameters so that editing
# it there actually changes the optimizer
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-6)

In [None]:
# Train for 100 epochs; train_transfer returns the model with the weights of the
# epoch that reached the highest training accuracy loaded into it
ResNet50_epoch_100 = train_transfer(model, criterion, optimizer, num_epochs=100)
# NOTE: this saves the whole model object (pickled), not just its state_dict
torch.save(ResNet50_epoch_100, 'ResNet50_epoch_100.pt')

### The model can be replaced with any other architecture.
### As an example, the structure of MoL2 is shown below.

In [None]:
# Pretrained MobileNetV2 that supplies the backbone for MoL2
mobilenet = mobilenet_v2(pretrained=True)

# Keep every top-level child module except the last one
# (presumably the classification head - verify against the installed torchvision)
feature_extractor = nn.Sequential(*list(mobilenet.children())[:-1])

class MoL2(nn.Module):
    """Two-class classifier built on the module-level `feature_extractor`.

    The backbone output is average-pooled to 1x1, flattened, and passed through
    Dropout(0.2) -> Linear(1280, 512) -> ReLU -> Dropout(0.2) -> Linear(512, 2).

    Note: `feature_extractor` is assigned as-is (not copied), so every MoL2
    instance refers to the same backbone module object.
    """

    def __init__(self):
        super().__init__()
        self.features = feature_extractor
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        head_layers = [
            nn.Dropout(0.2),
            nn.Linear(1280, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(512, 2),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the 2-way classifier outputs for the batch `x` (no softmax applied)."""
        pooled = self.avgpool(self.features(x))
        # Flatten everything after the batch dimension before the linear head
        return self.classifier(torch.flatten(pooled, 1))