# General and transfer image classification: binary image classification tasks A, B, and C, trained from scratch and via transfer learning with frozen weights (Fig. 2)

In [None]:
!pip3 uninstall --yes torch torchaudio torchvision torchtext torchdata
!pip3 install torch torchvision torchdata
!pip3 install torchrl

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
import sys
sys.path.append('/content/drive/My Drive/networkattention')
# vit model adapted from https://github.com/facebookresearch/dino/blob/main/README.md
import importlib
import torch
import torch.random
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader
import torch.utils.data as data_utils
import numpy as np
from PIL import Image
import vision_transformer
import fnmatch
import os
import glob
import shutil
import matplotlib.pyplot as plt
from copy import deepcopy
device = torch.device("cuda")

# IMAGE CLASSIFICATION A, B, AND C DATA LOADERS
# Train/val ImageFolder roots for the three binary classification tasks.
_DATA_ROOT = "/content/drive/My Drive/networkattention/data"

traindirA = f"{_DATA_ROOT}/train/classificationA"
valdirA = f"{_DATA_ROOT}/val/classificationA"

traindirB = f"{_DATA_ROOT}/train/classificationB"
valdirB = f"{_DATA_ROOT}/val/classificationB"

traindirC = f"{_DATA_ROOT}/train/classificationC"
valdirC = f"{_DATA_ROOT}/val/classificationC"

# Both splits use identical preprocessing: resize to 256x256, then convert to a tensor
# (no augmentation). They are separate objects so either pipeline can be changed alone.
train_transforms = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
val_transforms = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])


def schematrain(model, x, y, optimizer):
    """Run one optimisation step on an attention-schema model and return the loss.

    The loss is BCE-with-logits between the policy output and ``y``, plus
    0.05 x the MSE between the model's predicted attention and its ``h1m``
    output. Gradients are zeroed after the optimizer step, so they are clear
    for the next call.

    Args:
        model: module whose ``forward(x)`` returns ``(pred_attn, h1m, policy)``.
        x: input batch.
        y: float targets shaped like ``policy``.
        optimizer: optimizer over the parameters to update.

    Returns:
        The combined loss tensor (still attached to the graph).
    """
    predicted_attention, attention_target, logits = model.forward(x)

    attention_term = torch.nn.MSELoss()(predicted_attention, attention_target) * 0.05
    policy_term = torch.nn.BCEWithLogitsLoss()(logits, y)
    combined = attention_term + policy_term

    combined.backward()
    optimizer.step()
    optimizer.zero_grad()
    return combined

def schematrain_policy(model, x, y, optimizer):
    """Run one optimisation step on an attention-schema model using only the policy loss.

    Unlike ``schematrain`` there is no attention-prediction term: only the
    BCE-with-logits between the policy output and ``y`` is back-propagated.
    Gradients are zeroed after the optimizer step.

    Args:
        model: module whose ``forward(x)`` returns ``(pred_attn, h1m, policy)``.
        x: input batch.
        y: float targets shaped like ``policy``.
        optimizer: optimizer over the parameters to update.

    Returns:
        The policy loss tensor.
    """
    _pred_attn, _h1m, logits = model.forward(x)
    objective = torch.nn.BCEWithLogitsLoss()(logits, y)

    objective.backward()
    optimizer.step()
    optimizer.zero_grad()
    return objective

def controltrain(model, x, y, optimizer):
    """Run one optimisation step on a control model and return the policy loss.

    Args:
        model: module whose ``forward(x)`` returns ``(h1, policy)``.
        x: input batch.
        y: float targets shaped like ``policy``.
        optimizer: optimizer over the parameters to update.

    Returns:
        The BCE-with-logits loss tensor; gradients are zeroed after the step.
    """
    _hidden, logits = model.forward(x)
    objective = torch.nn.BCEWithLogitsLoss()(logits, y)

    objective.backward()
    optimizer.step()
    optimizer.zero_grad()
    return objective

def fitschema(model, trainloader, valloader, name="", n_epochs=20, policy_only=False):
    """Train an attention-schema model, then restore the weights with the lowest val loss.

    Each training step calls ``schematrain`` (policy BCE + 0.05 x attention MSE),
    or ``schematrain_policy`` (policy BCE only) when ``policy_only`` is True.
    After every epoch the policy BCE is averaged over ``valloader``. A snapshot of
    the weights from the epoch with the lowest validation loss is loaded back into
    ``model`` at the end.

    Args:
        model: module whose forward returns ``(pred_attn, h1m, policy)``.
        trainloader: DataLoader yielding ``(inputs, labels)`` training batches.
        valloader: DataLoader yielding ``(inputs, labels)`` validation batches.
        name: prefix for the loss-curve file.
        n_epochs: number of passes over ``trainloader``.
        policy_only: train on the policy loss only (skip the attention-prediction term).

    Side effects:
        Writes the per-epoch mean training loss, one value per line, to
        ``.../losscurves/<name>schema.txt``.
    """
    bce = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    train_step = schematrain_policy if policy_only else schematrain

    epoch_train_losses = []
    best_loss = float("inf")
    best_model_wts = None

    for epoch in range(n_epochs):
        model.train()
        epoch_loss = 0
        for i, (x_batch, y_batch) in enumerate(trainloader):
            x_batch = x_batch.to(device)
            # Targets become float and (batch, 1) to match the policy logits.
            y_batch = y_batch.to(device).unsqueeze(1).float()
            loss_value = train_step(model, x_batch, y_batch, optimizer).item()
            epoch_loss += loss_value / len(trainloader)
            if epoch == 0:
                print(f"{i}: {loss_value} / {len(trainloader)}: {epoch_loss}")
        epoch_train_losses.append(epoch_loss)
        print('\nEpoch : {}, train loss : {}'.format(epoch + 1, epoch_loss))

        model.eval()
        cum_loss = 0
        with torch.no_grad():
            for x_batch, y_batch in valloader:
                x_batch = x_batch.to(device)
                y_batch = y_batch.unsqueeze(1).float().to(device)
                _, _, policy = model(x_batch)
                cum_loss += bce(policy, y_batch).item() / len(valloader)
        print('Epoch : {}, val loss : {}'.format(epoch + 1, cum_loss))

        # state_dict() returns references to the live parameter tensors. Without a deep
        # copy, later epochs overwrite the "best" snapshot in place and the final
        # load_state_dict silently keeps the last-epoch weights.
        if cum_loss <= best_loss:
            best_loss = cum_loss
            best_model_wts = deepcopy(model.state_dict())

    # None when n_epochs == 0 (or every val loss was NaN): keep the current weights.
    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)

    with open("/content/drive/My Drive/networkattention/losscurves/" + name + "schema.txt", "w") as file:
        file.writelines(f"{item}\n" for item in epoch_train_losses)

def fitcontrol(model, trainloader, valloader, name="", n_epochs=20):
    """Train a control model, then restore the weights with the lowest validation loss.

    Each training step calls ``controltrain`` (policy BCE-with-logits). After every
    epoch the policy BCE is averaged over ``valloader``. A snapshot of the weights
    from the epoch with the lowest validation loss is loaded back into ``model``
    at the end.

    Args:
        model: module whose forward returns ``(h1, policy)``.
        trainloader: DataLoader yielding ``(inputs, labels)`` training batches.
        valloader: DataLoader yielding ``(inputs, labels)`` validation batches.
        name: prefix for the loss-curve file.
        n_epochs: number of passes over ``trainloader``.

    Side effects:
        Writes the per-epoch mean training loss, one value per line, to
        ``.../losscurves/<name>control.txt``.
    """
    bce = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    epoch_train_losses = []
    best_loss = float("inf")
    best_model_wts = None

    for epoch in range(n_epochs):
        model.train()
        epoch_loss = 0
        for i, (x_batch, y_batch) in enumerate(trainloader):
            x_batch = x_batch.to(device)
            # Targets become float and (batch, 1) to match the policy logits.
            y_batch = y_batch.to(device).unsqueeze(1).float()
            loss_value = controltrain(model, x_batch, y_batch, optimizer).item()
            epoch_loss += loss_value / len(trainloader)
            if epoch == 0:
                print(f"{i}: {loss_value} / {len(trainloader)}: {epoch_loss}")
        epoch_train_losses.append(epoch_loss)
        print('\nEpoch : {}, train loss : {}'.format(epoch + 1, epoch_loss))

        model.eval()
        cum_loss = 0
        with torch.no_grad():
            for x_batch, y_batch in valloader:
                x_batch = x_batch.to(device)
                y_batch = y_batch.unsqueeze(1).float().to(device)
                _, policy = model(x_batch)
                cum_loss += bce(policy, y_batch).item() / len(valloader)
        print('Epoch : {}, val loss : {}'.format(epoch + 1, cum_loss))

        # state_dict() returns references to the live parameter tensors. Without a deep
        # copy, later epochs overwrite the "best" snapshot in place and the final
        # load_state_dict silently keeps the last-epoch weights.
        if cum_loss <= best_loss:
            best_loss = cum_loss
            best_model_wts = deepcopy(model.state_dict())

    # None when n_epochs == 0 (or every val loss was NaN): keep the current weights.
    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)

    with open("/content/drive/My Drive/networkattention/losscurves/" + name + "control.txt", "w") as file:
        file.writelines(f"{item}\n" for item in epoch_train_losses)

def evaluate(model, valloader, name="", save_attn=False):
    """Measure binary classification accuracy of ``model`` on ``valloader`` and save results.

    The model's last output is treated as the policy logit. A sample is predicted
    as ``round(sigmoid(logit))``. Accuracy is the fraction of correctly classified
    samples across the whole loader. A smaller final batch is therefore not
    over-weighted, as it would be if per-batch accuracies were averaged.

    Args:
        model: module whose forward returns a tuple/list whose last element is the policy logit.
        valloader: DataLoader yielding ``(inputs, labels)`` batches.
        name: suffix/prefix used in the output file names.
        save_attn: accepted for call compatibility; currently unused.

    Side effects:
        Writes, under ``/content/drive/My Drive/networkattention/``:
          - ``accuracies/acc<name>.txt``: overall accuracy;
          - ``classifications/<name>_classifications.txt``: list of predicted labels;
          - ``classifications/<name>_labels.txt``: list of true labels (same order).
    """
    classifications = []
    labels = []
    n_correct = 0
    n_total = 0
    model.eval()
    # Inference only: no autograd graph is needed, which saves GPU memory.
    with torch.no_grad():
        for x_batch, y_batch in valloader:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device).unsqueeze(1).float()
            policy = model(x_batch)[-1]
            preds = torch.round(torch.sigmoid(policy))
            n_correct += (preds == y_batch).sum().item()
            n_total += len(y_batch)
            classifications.extend(preds.flatten().tolist())
            labels.extend(y_batch.flatten().tolist())

    total_acc = n_correct / n_total if n_total else 0

    with open("/content/drive/My Drive/networkattention/accuracies/acc" + name + ".txt", "w") as file:
        file.write(str(total_acc))

    with open("/content/drive/My Drive/networkattention/classifications/" + name + "_classifications.txt", "w") as file:
        file.write(str(classifications))

    with open("/content/drive/My Drive/networkattention/classifications/" + name + "_labels.txt", "w") as file:
        file.write(str(labels))

def freeze_models(models):
    """Freeze every parameter of each model except those of its ``policy`` submodule.

    Args:
        models: iterable of modules, each exposing a ``policy`` submodule.
    """
    for net in models:
        # Turn gradients off everywhere, then back on for the policy submodule only.
        net.requires_grad_(False)
        net.policy.requires_grad_(True)


Mounted at /content/drive


In [None]:
def _make_loaders(traindir, valdir, batch_size=8):
    """Return shuffled (train, val) DataLoaders over the ImageFolder datasets in traindir / valdir."""
    train_data = datasets.ImageFolder(traindir, transform=train_transforms)
    val_data = datasets.ImageFolder(valdir, transform=val_transforms)
    trainloader = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=batch_size)
    valloader = torch.utils.data.DataLoader(val_data, shuffle=True, batch_size=batch_size)
    return trainloader, valloader


# Train/val directories for each binary classification task.
_TASK_DIRS = {
    "A": (traindirA, valdirA),
    "B": (traindirB, valdirB),
    "C": (traindirC, valdirC),
}
# (source task, target task): a model trained on the source task is frozen except for its
# `policy` submodule (see freeze_models) and then fine-tuned on the target task.
_TRANSFER_PAIRS = [("A", "B"), ("B", "C"), ("C", "A")]

for trial in range(0, 20):
    torch.random.seed()
    tname = "NEWtrial" + str(trial)

    # FIT / EVALUATE MODELS: IMAGE CLASSIFICATION
    schema_models = {}
    control_models = {}
    for task, (traindir, valdir) in _TASK_DIRS.items():
        trainloader, valloader = _make_loaders(traindir, valdir)

        schema_models[task] = vision_transformer.VitAttentionSchema().to(device)
        control_models[task] = vision_transformer.VitControl().to(device)

        fitschema(schema_models[task], trainloader, valloader, tname + task)
        evaluate(schema_models[task], valloader, tname + task + "schema", save_attn=False)
        fitcontrol(control_models[task], trainloader, valloader, tname + task)
        evaluate(control_models[task], valloader, tname + task + "control", save_attn=False)

        del trainloader, valloader

    # FREEZE THE MODELS
    freeze_models([m for task in _TASK_DIRS for m in (schema_models[task], control_models[task])])

    # TRANSFER LEARNING
    for src, dst in _TRANSFER_PAIRS:
        trainloader, valloader = _make_loaders(*_TASK_DIRS[dst])
        schema = schema_models.pop(src)
        control = control_models.pop(src)

        # `policy_only` is a fitschema option. Passing it to evaluate() raised TypeError.
        fitschema(schema, trainloader, valloader, tname + src + "schema" + dst, policy_only=True)
        evaluate(schema, valloader, tname + src + "schema" + dst)

        fitcontrol(control, trainloader, valloader, tname + src + "control" + dst)
        # Evaluate the control model itself. For A->B the schema model was evaluated
        # under the control name.
        evaluate(control, valloader, tname + src + "control" + dst)

        del trainloader, valloader, schema, control




Streaming output truncated to the last 5000 lines.
61: 4.5884857177734375 / 239: 1.2105814336982232
62: 3.958615779876709 / 239: 1.2271446796391299
63: 2.531510829925537 / 239: 1.237736775161831
64: 1.907580852508545 / 239: 1.2457182850049628
65: 4.097668170928955 / 239: 1.2628633401134521
66: 2.2190558910369873 / 239: 1.2721480927956152
67: 5.5682053565979 / 239: 1.2954460231579497
68: 3.613309621810913 / 239: 1.3105644734584136
69: 3.0307812690734863 / 239: 1.3232455666344534
70: 9.032301902770996 / 239: 1.3610376248050433
71: 5.214232444763184 / 239: 1.382854496958864
72: 1.7016751766204834 / 239: 1.3899744767773599
73: 1.406920313835144 / 239: 1.3958611726511472
74: 6.43858528137207 / 239: 1.4228008600209048
75: 6.999792098999023 / 239: 1.4520886930711099
76: 3.2254486083984375 / 239: 1.4655842939430699
77: 6.765417575836182 / 239: 1.4938914804528447
78: 4.7683634757995605 / 239: 1.5138427920670687
79: 3.108198404312134 / 239: 1.5268478063110527
80: 2.4469153881073 / 

KeyboardInterrupt: 