In [1]:
dataset = 'BreakHis1'
method = 'MOON'
i = 4
learningRate = 0.001
epoch = 10
mu = 10
tau = 0.5
train_dir = '/DATA1/Mangaldeep/V3/Dataset/BreakHis1/train'
val_dir = '/DATA1/Mangaldeep/V3/Dataset/BreakHis1/val'
test_dir = '/DATA1/Mangaldeep/V3/Dataset/BreakHis1/test'
base_weight_dir = '/DATA1/Mangaldeep/V3/Weight/MOON'
base_log_dir = '/DATA1/Mangaldeep/V3/Log/MOON'
avg_weight_dir = base_weight_dir+'/avg'+str(i)+'.pth'
logfilepath = base_log_dir+'/'+str(method)+'_'+str(dataset)+'_'+str(i)+'.txt'

In [2]:
# License: BSD
# Author: Sasank Chilamkurthy

from __future__ import print_function, division

import os
import time
import copy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torchvision.transforms import ToTensor,Resize,Normalize,RandomHorizontalFlip,RandomVerticalFlip,RandomCrop,CenterCrop
from sklearn.metrics import accuracy_score,precision_recall_fscore_support,roc_curve,auc,roc_auc_score,classification_report
from torch.nn.functional import cosine_similarity
import matplotlib.pyplot as plt


cudnn.benchmark = True
plt.ion()   # interactive mode

seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

In [3]:
# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        Resize((256,256),interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.RandomCrop(224),
        #transforms.RandomResizedCrop(size = (224,224)),
        #transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(p=0.3),
        transforms.RandomVerticalFlip(p=0.3),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        #transforms.Resize(256),
        Resize((256,256),interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

#data_dir = 'data/hymenoptera_data'
train_dataset = datasets.ImageFolder(train_dir,data_transforms['train'])
test_dataset = datasets.ImageFolder(val_dir,data_transforms['val'])

# define a loader for the training data we can iterate through in 32-image batches
train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=40,
        num_workers=6,
        shuffle=True
    )

# define a loader for the testing data we can iterate through in 32-image batches
test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=40,
        num_workers=4,
        shuffle=True
    )

dataset_sizes = {'train': len(train_dataset) ,'val': len(test_dataset)}
class_names = train_dataset.classes
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
sample_size = len(train_dataset)
with open (logfilepath, 'a') as file:
    file.writelines(f"{method} training started for {dataset} dataset")
    file.writelines('\n')
    file.writelines(f"{i}th round learning rate {learningRate} mu value {mu}")
    file.writelines('\n')
    file.writelines(f"Train sample_size: {len(train_dataset)} ,validation: {len(test_dataset)}")
    file.writelines('\n')

In [4]:
##Global model and previous model
if i >0 :
    global_model_path = base_weight_dir+'/avg'+str(i-1)+'.pth'
    prev_model_path = base_weight_dir+'/' +str(method)+'_'+str(dataset)+'_'+str(i-1)+'_'+str(sample_size)+'.pth'
else:
    global_model_path = base_weight_dir+'/avg'+str(0)+'.pth'
    prev_model_path = base_weight_dir+'/avg'+str(0)+'.pth'

temp_model = models.efficientnet_b3(weights=None)
num_ftrs = temp_model.classifier[1].in_features
temp_model.classifier[1] = nn.Linear(num_ftrs, 2)    
global_model = copy.deepcopy(temp_model)
prev_model = copy.deepcopy(temp_model)

global_model.load_state_dict(torch.load(global_model_path))
prev_model.load_state_dict(torch.load(prev_model_path))


global_model = global_model.to(device)
prev_model = prev_model.to(device)

In [5]:
'''#NEw code
from torchvision.models.feature_extraction import create_feature_extractor
def small_model(model):
    return_modes = {"features.8": "features.8"}
    model2 = create_feature_extractor(model,return_nodes=return_modes)
    return model2

def get_final_features(model, x, detach=True):
    func = (lambda x: x.clone().detach()) if detach else (lambda x: x)
    x = torch.flatten(x, start_dim=1)
    model2 = small_model(model)
    op = model2(x)
    print (op.shape)
    return func(x)'''

'#NEw code\nfrom torchvision.models.feature_extraction import create_feature_extractor\ndef small_model(model):\n    return_modes = {"features.8": "features.8"}\n    model2 = create_feature_extractor(model,return_nodes=return_modes)\n    return model2\n\ndef get_final_features(model, x, detach=True):\n    func = (lambda x: x.clone().detach()) if detach else (lambda x: x)\n    x = torch.flatten(x, start_dim=1)\n    model2 = small_model(model)\n    op = model2(x)\n    print (op.shape)\n    return func(x)'

In [6]:
def train_model(model,global_model,prev_model, criterion, optimizer, scheduler, num_epochs=25,mu=10,tau = 0.5):
    since = time.time()
    #change1
    train_err = []
    val_err = []
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_valLoss = 0.0
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            if phase == 'train' :
              dataloaders = train_loader
            else :
              dataloaders = test_loader
            # Iterate over data.
            for inputs, labels in dataloaders:
                inputs = inputs.to(device)
                labels = labels.to(device)
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                  z_curr = model(inputs)
                  #print(z_curr.grad_fn)
                  z_global = global_model(inputs).detach()#.clone().detach()#.cpu().numpy()
                  #z_global = torch.tensor(z_global,requires_grad=True).to(device)
                  #print(z_curr.grad_fn)
                  z_prev = prev_model(inputs).detach()#.clone().detach()#.cpu().numpy()
                  #print(z_prev.grad_fn)
                  #z_prev = torch.tensor(z_prev,requires_grad=True).to(device)
                  logit = model(inputs)
                  _, preds = torch.max(logit, 1)
                  loss_sup = criterion(logit, labels)
                  loss_con = -torch.log(
                    torch.exp(
                        cosine_similarity(z_curr.flatten(1), z_global.flatten(1))
                        / tau)/ (
                        torch.exp(
                            cosine_similarity(z_prev.flatten(1), z_curr.flatten(1))
                            / tau)
                        + torch.exp(
                            cosine_similarity(z_curr.flatten(1), z_global.flatten(1))
                            / tau
                        )))
                  loss = loss_sup + mu * torch.mean(loss_con)
                  if phase == 'train':
                    #optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()    
                  #loss_con.requires_grad = True

                  
                  
                    

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            if phase == 'train':
              train_err.append(epoch_loss)
            else :
              val_err.append(epoch_loss)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            with open (logfilepath, 'a') as file:
              file.writelines(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
              file.writelines('\n')

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                best_valLoss = epoch_loss

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f} Validation loss : {best_valLoss:4f}')
    with open (logfilepath, 'a') as file:
              file.writelines(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
              file.writelines('\n')
              file.writelines(f'Best val Acc: {best_acc:4f} Validation loss : {best_valLoss:4f}')
              file.writelines('\n')
              file.writelines(f'Training error = {train_err}')
              file.writelines('\n')
              file.writelines(f'Validation error = {val_err}')
              file.writelines('\n')


    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [7]:
model_ft = models.efficientnet_b3(weights=None)
num_ftrs = model_ft.classifier[1].in_features
model_ft.classifier[1] = nn.Linear(num_ftrs, 2)

if avg_weight_dir != '':
  model_ft.load_state_dict(torch.load(avg_weight_dir,map_location = 'cpu'))


model_ft = model_ft.to(device)
model_ft.train()
criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
#optimizer_ft = optim.SGD(model_ft.parameters(), lr=learningRate, momentum=0.9)
optimizer_ft = optim.Adam(model_ft.parameters(), lr=learningRate)
# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.5)

In [8]:
model_ft = train_model(model_ft,global_model,prev_model, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=epoch,mu=mu)

Epoch 0/9
----------


train Loss: 7.5895 Acc: 0.5481
val Loss: 3.8572 Acc: 0.3333

Epoch 1/9
----------
train Loss: 3.5994 Acc: 0.3615
val Loss: 2.0006 Acc: 0.3131

Epoch 2/9
----------
train Loss: 2.1798 Acc: 0.3372
val Loss: 1.5573 Acc: 0.3131

Epoch 3/9
----------
train Loss: 1.8073 Acc: 0.3529
val Loss: 1.8947 Acc: 0.3131

Epoch 4/9
----------
train Loss: 1.7880 Acc: 0.3443
val Loss: 1.7263 Acc: 0.3232

Epoch 5/9
----------
train Loss: 1.7056 Acc: 0.3587
val Loss: 2.2309 Acc: 0.2626

Epoch 6/9
----------
train Loss: 1.9656 Acc: 0.3429
val Loss: 1.7085 Acc: 0.3131

Epoch 7/9
----------
train Loss: 1.8856 Acc: 0.3902
val Loss: 1.4621 Acc: 0.3333

Epoch 8/9
----------
train Loss: 1.7519 Acc: 0.3630
val Loss: 1.7393 Acc: 0.3232

Epoch 9/9
----------
train Loss: 1.7639 Acc: 0.3501
val Loss: 1.5295 Acc: 0.3434

Training complete in 1m 13s
Best val Acc: 0.343434 Validation loss : 1.529516


In [9]:
for name, param in model_ft.named_parameters():
    if not param.requires_grad:
        print(f"Parameter {name} does not require gradients.")


In [10]:
from torchvision.models.feature_extraction import get_graph_node_names
print(get_graph_node_names(model_ft))

(['x', 'features.0', 'features.1.0.block.0', 'features.1.0.block.1', 'features.1.0.block.2', 'features.1.1.block.0', 'features.1.1.block.1', 'features.1.1.block.2', 'features.1.1.stochastic_depth', 'features.1.1.add', 'features.2.0.block.0', 'features.2.0.block.1', 'features.2.0.block.2', 'features.2.0.block.3', 'features.2.1.block.0', 'features.2.1.block.1', 'features.2.1.block.2', 'features.2.1.block.3', 'features.2.1.stochastic_depth', 'features.2.1.add', 'features.2.2.block.0', 'features.2.2.block.1', 'features.2.2.block.2', 'features.2.2.block.3', 'features.2.2.stochastic_depth', 'features.2.2.add', 'features.3.0.block.0', 'features.3.0.block.1', 'features.3.0.block.2', 'features.3.0.block.3', 'features.3.1.block.0', 'features.3.1.block.1', 'features.3.1.block.2', 'features.3.1.block.3', 'features.3.1.stochastic_depth', 'features.3.1.add', 'features.3.2.block.0', 'features.3.2.block.1', 'features.3.2.block.2', 'features.3.2.block.3', 'features.3.2.stochastic_depth', 'features.3.2.

In [11]:
model_weight_dir = base_weight_dir+'/' +str(method)+'_'+str(dataset)+'_'+str(i)+'_'+str(sample_size)+'.pth'
torch.save(model_ft.state_dict(), model_weight_dir)

In [12]:
test_dataset = datasets.ImageFolder(test_dir, transform=data_transforms['val'])
class_names = test_dataset.classes
test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        num_workers=0,
        shuffle=False
    )
model_ft.eval()
with torch.no_grad():
  y_true=[]
  y_pred=[]
  y_probas = []
  for i, (inputs, labels) in enumerate(test_loader):
    inputs = inputs.to(device)
    labels = labels.to(device)
    y_true.append(labels.detach().cpu().numpy())
    model_ft.to(device)
    outputs = model_ft(inputs)
    outputs = nn.Softmax(dim=1)(outputs)
    max, preds = torch.max(outputs, 1)
    y_probas.append(outputs.cpu().numpy())
    y_pred.append(int(preds.detach().cpu().numpy()))


y_true = np.array(y_true)
y_pred = np.array(y_pred)
prob = np.array(y_probas)
acc = accuracy_score(y_true,y_pred)
y_probas = prob.reshape(prob.shape[0],prob.shape[2])
precision,recall,fscore,_=precision_recall_fscore_support(y_true, y_pred, average='macro')
auc_score = roc_auc_score(y_true, y_probas[:,1])
precision,recall,fscore,_=precision_recall_fscore_support(y_true, y_pred, average='macro')
print(f"accuracy = {round(acc,2)} , precision = {round(precision,2)},recall = {round(recall,2)},fscore = {round(fscore,2)},,auc_score = {round(auc_score,2)}")
print(classification_report(y_true,y_pred))
with open (logfilepath, 'a') as file:
    file.writelines(f"accuracy = {round(acc,2)} , precision = {round(precision,2)},recall = {round(recall,2)},fscore = {round(fscore,2)},,auc_score = {round(auc_score,2)}")
    file.writelines('\n')
    file.writelines(classification_report(y_true,y_pred))
    file.writelines('\n')

accuracy = 0.4 , precision = 0.56,recall = 0.54,fscore = 0.39,,auc_score = 0.47
              precision    recall  f1-score   support

           0       0.33      0.89      0.48        63
           1       0.78      0.18      0.30       137

    accuracy                           0.41       200
   macro avg       0.56      0.54      0.39       200
weighted avg       0.64      0.41      0.36       200

