# Experiments 1 and 2, general and transfer image classification tasks: binary image classification tasks A, B, and C, with full training and transfer learning / frozen weights (Fig 3)

In [None]:
# To run the notebook, set dir_path to your project directory. The data
# folders and the accuracy_*.txt result files are all resolved relative to it.
dir_path = '/path/to/my/directory'

# Uncomment the two lines below to mount Google Drive when running in Google Colab
# from google.colab import drive
# drive.mount('/content/drive', force_remount=True)
import sys
# Put dir_path on the module search path
# (presumably where the local vision_transformer module imported below lives).
sys.path.append(dir_path)

# ViT model based on https://github.com/facebookresearch/dino/blob/main/README.md
import importlib
import torch
import torch.random
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader
import torch.utils.data as data_utils
import numpy as np
from PIL import Image
import vision_transformer
import fnmatch
import os
import glob
import shutil
import matplotlib.pyplot as plt
from copy import deepcopy
# Device used for every model and batch in this notebook. Prefer the GPU, but
# fall back to the CPU so the notebook can still run on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# IMAGE CLASSIFICATION A, B, AND C DATA LOADERS

# Image folders for each binary classification task, one train/val pair per task.
traindirA = f"{dir_path}/data/train/classificationA"
valdirA = f"{dir_path}/data/val/classificationA"

traindirB = f"{dir_path}/data/train/classificationB"
valdirB = f"{dir_path}/data/val/classificationB"

traindirC = f"{dir_path}/data/train/classificationC"
valdirC = f"{dir_path}/data/val/classificationC"

# Training and validation share the same preprocessing:
# resize to 256x256, then convert the image to a tensor.
train_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])
val_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])


def schematrain(model, x, y, optimizer):
    """Run one optimisation step on the combined schema + policy loss.

    The model's forward pass must return ``(pred_attn, h1m, policy)``. The loss
    is ``0.05 * MSE(pred_attn, h1m) + BCEWithLogits(policy, y)``.

    Args:
        model: module whose ``forward(x)`` returns the three outputs above.
        x: input batch.
        y: float targets with the same shape as ``policy``.
        optimizer: optimizer stepped once; its gradients are cleared afterwards.

    Returns:
        The combined loss tensor for this batch.
    """
    functional = torch.nn.functional
    pred_attn, h1m, policy = model.forward(x)

    # The 0.05 factor down-weights the attention-prediction (MSE) term
    # relative to the classification term.
    schema_term = 0.05 * functional.mse_loss(pred_attn, h1m)
    policy_term = functional.binary_cross_entropy_with_logits(policy, y)
    total_loss = schema_term + policy_term

    total_loss.backward()
    optimizer.step()
    # Clear gradients after the step so the next call starts from a clean slate.
    optimizer.zero_grad()
    return total_loss

def schematrain_policy(model, x, y, optimizer):
    """Run one optimisation step on the policy (classification) loss only.

    The model's forward pass must return ``(pred_attn, h1m, policy)``; the first
    two outputs are ignored here and only ``BCEWithLogits(policy, y)`` is
    back-propagated.

    Args:
        model: module whose ``forward(x)`` returns the three outputs above.
        x: input batch.
        y: float targets with the same shape as ``policy``.
        optimizer: optimizer stepped once; its gradients are cleared afterwards.

    Returns:
        The policy loss tensor for this batch.
    """
    _, _, logits = model.forward(x)
    policy_loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, y)

    policy_loss.backward()
    optimizer.step()
    # Clear gradients after the step so the next call starts from a clean slate.
    optimizer.zero_grad()
    return policy_loss

def controltrain(model, x, y, optimizer):
    """Run one optimisation step for a control model.

    The model's forward pass must return ``(h1, policy)``; only
    ``BCEWithLogits(policy, y)`` is back-propagated.

    Args:
        model: module whose ``forward(x)`` returns the two outputs above.
        x: input batch.
        y: float targets with the same shape as ``policy``.
        optimizer: optimizer stepped once; its gradients are cleared afterwards.

    Returns:
        The policy loss tensor for this batch.
    """
    _, logits = model.forward(x)
    policy_loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, y)

    policy_loss.backward()
    optimizer.step()
    # Clear gradients after the step so the next call starts from a clean slate.
    optimizer.zero_grad()
    return policy_loss

def fitschema(model, trainloader, valloader, name="", n_epochs=20, policy_only=False):
  """Train an attention-schema model and restore its best-validation weights.

  Each epoch runs one pass over ``trainloader`` (Adam, lr=1e-4) followed by a
  pass over ``valloader`` that measures the mean per-batch BCE-with-logits
  loss of the model's third output (the policy logits). When training ends,
  the model is reset to the weights of the epoch with the lowest validation
  loss (the latest such epoch on ties).

  Args:
      model: module returning ``(pred_attn, h1m, policy)``; updated in place.
      trainloader: iterable of ``(x, y)`` training batches supporting ``len()``.
      valloader: iterable of ``(x, y)`` validation batches supporting ``len()``.
      name: unused; kept so existing callers keep working.
      n_epochs: number of training epochs.
      policy_only: if True, optimise only the policy loss
          (``schematrain_policy``); otherwise the combined loss
          (``schematrain``).
  """
  bce = torch.nn.BCEWithLogitsLoss()
  optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
  train_step = schematrain_policy if policy_only else schematrain
  n_train_batches = len(trainloader)
  n_val_batches = len(valloader)

  best_loss = float("inf")
  # state_dict() hands back references to the live parameter tensors, which
  # the optimizer keeps updating in place. A deep copy is required to really
  # freeze a snapshot; starting from the initial weights also keeps
  # n_epochs=0 well-defined.
  best_model_wts = deepcopy(model.state_dict())

  for epoch in range(n_epochs):
    # ---- training pass ----
    model.train()
    epoch_loss = 0
    for i, (x_batch, y_batch) in enumerate(trainloader):
      x_batch, y_batch = x_batch.to(device), y_batch.to(device)
      # Labels arrive as a 1-D tensor; the loss expects float targets
      # shaped like the logits.
      y_batch = y_batch.unsqueeze(1).float()
      batch_loss = train_step(model, x_batch, y_batch, optimizer).item()
      epoch_loss += batch_loss/n_train_batches
      if epoch == 0:
        # Per-batch progress is only printed during the first epoch.
        print(str(i)+": "+str(batch_loss)+" / "+str(n_train_batches)+": "+str(epoch_loss))
    print('\nEpoch : {}, train loss : {}'.format(epoch+1,epoch_loss))

    # ---- validation pass ----
    model.eval()
    total_loss = 0
    with torch.no_grad():
      for x_batch, y_batch in valloader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.unsqueeze(1).float().to(device)
        _, _, policy = model(x_batch)
        total_loss += bce(policy, y_batch).item()/n_val_batches
    print('Epoch : {}, val loss : {}'.format(epoch+1,total_loss))

    # Keep a private copy of the best weights seen so far.
    if total_loss <= best_loss:
      best_loss = total_loss
      best_model_wts = deepcopy(model.state_dict())

  model.load_state_dict(best_model_wts)

def fitcontrol(model, trainloader, valloader, name="", n_epochs=20):
  """Train a control model and restore its best-validation weights.

  Each epoch runs one pass over ``trainloader`` (Adam, lr=1e-4, policy loss
  via ``controltrain``) followed by a pass over ``valloader`` that measures
  the mean per-batch BCE-with-logits loss of the model's second output (the
  policy logits). When training ends, the model is reset to the weights of
  the epoch with the lowest validation loss (the latest such epoch on ties).

  Args:
      model: module returning ``(h1, policy)``; updated in place.
      trainloader: iterable of ``(x, y)`` training batches supporting ``len()``.
      valloader: iterable of ``(x, y)`` validation batches supporting ``len()``.
      name: unused; kept so existing callers keep working.
      n_epochs: number of training epochs.
  """
  bce = torch.nn.BCEWithLogitsLoss()
  optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
  n_train_batches = len(trainloader)
  n_val_batches = len(valloader)

  best_loss = float("inf")
  # state_dict() hands back references to the live parameter tensors, which
  # the optimizer keeps updating in place. A deep copy is required to really
  # freeze a snapshot; starting from the initial weights also keeps
  # n_epochs=0 well-defined.
  best_model_wts = deepcopy(model.state_dict())

  for epoch in range(n_epochs):
    # ---- training pass ----
    model.train()
    epoch_loss = 0
    for i, (x_batch, y_batch) in enumerate(trainloader):
      x_batch, y_batch = x_batch.to(device), y_batch.to(device)
      # Labels arrive as a 1-D tensor; the loss expects float targets
      # shaped like the logits.
      y_batch = y_batch.unsqueeze(1).float()
      batch_loss = controltrain(model, x_batch, y_batch, optimizer).item()
      epoch_loss += batch_loss/n_train_batches
      if epoch == 0:
        # Per-batch progress is only printed during the first epoch.
        print(str(i)+": "+str(batch_loss)+" / "+str(n_train_batches)+": "+str(epoch_loss))
    print('\nEpoch : {}, train loss : {}'.format(epoch+1,epoch_loss))

    # ---- validation pass ----
    model.eval()
    total_loss = 0
    with torch.no_grad():
      for x_batch, y_batch in valloader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.unsqueeze(1).float().to(device)
        _, policy = model(x_batch)
        total_loss += bce(policy, y_batch).item()/n_val_batches
    print('Epoch : {}, val loss : {}'.format(epoch+1,total_loss))

    # Keep a private copy of the best weights seen so far.
    if total_loss <= best_loss:
      best_loss = total_loss
      best_model_wts = deepcopy(model.state_dict())

  model.load_state_dict(best_model_wts)

def evaluate(model, valloader, name="", save_attn=False):
  """Measure classification accuracy on ``valloader`` and write it to disk.

  The last element of the model's output is treated as the policy logits; a
  sigmoid followed by rounding turns them into 0/1 predictions. The reported
  value is the mean of the per-batch accuracies, so a smaller final batch
  carries the same weight as a full one. The result is written as plain text
  to ``dir_path + "/accuracy_" + name + ".txt"``, overwriting any existing file.

  Args:
      model: module whose output's last element holds the policy logits.
      valloader: iterable of ``(x, y)`` batches supporting ``len()``.
      name: suffix used in the output file name.
      save_attn: unused; kept so existing callers keep working.
  """
  model.eval()
  n_batches = len(valloader)
  total_acc = 0
  # Inference only: skip autograd bookkeeping to save memory and time.
  with torch.no_grad():
    for x_batch, y_batch in valloader:
      x_batch, y_batch = x_batch.to(device), y_batch.to(device)
      y_batch = y_batch.unsqueeze(1).float()
      outputs = model.forward(x_batch)
      # Works for both model types: the policy logits are always last.
      policy = torch.round(torch.sigmoid(outputs[-1]))
      # Fraction of predictions in this batch that match the labels.
      accuracy = 1-(torch.sum(abs(policy - y_batch))/len(y_batch))
      total_acc += accuracy.item()/n_batches

  # The context manager closes the file even if the write fails.
  with open(dir_path+"/accuracy_"+name+".txt","w") as file:
    file.write(str(total_acc))

def freeze_models(models):
  """Freeze every parameter of each model except those of its ``policy`` head.

  Args:
      models: iterable of modules that each expose a ``policy`` sub-module.
  """
  for net in models:
    # First pass: switch off gradients for the whole network ...
    for weight in net.parameters():
      weight.requires_grad = False
    # ... second pass: switch them back on for the policy head only.
    head = net.policy
    for weight in head.parameters():
      weight.requires_grad = True


In [None]:
# One fixed seed per trial so every trial is reproducible.
seeds = [43250, 58038, 70991, 85884, 88252, 98122, 59732, 59721, 34361,
         24375, 17167, 25532, 24606, 27055, 77062, 27850, 93109, 37718,
         70332, 75087]

# Train/validation image folders for each classification task.
task_dirs = {
    "A": (traindirA, valdirA),
    "B": (traindirB, valdirB),
    "C": (traindirC, valdirC),
}

# (source task, target task) pairs for the transfer experiment: a model fully
# trained on the source task is frozen and then re-fitted on the target task.
transfer_pairs = (("A", "B"), ("B", "C"), ("C", "A"))


def make_loaders(task):
  """Return shuffled (trainloader, valloader) with batch size 8 for one task."""
  traindir, valdir = task_dirs[task]
  train_data = datasets.ImageFolder(traindir,transform=train_transforms)
  val_data = datasets.ImageFolder(valdir,transform=val_transforms)
  trainloader = torch.utils.data.DataLoader(train_data, shuffle = True, batch_size=8)
  valloader = torch.utils.data.DataLoader(val_data, shuffle = True, batch_size=8)
  return trainloader, valloader


for trial, seed in enumerate(seeds):
  torch.manual_seed(seed)
  tname = "trial"+str(trial)

  # FIT / EVALUATE MODELS: IMAGE CLASSIFICATION (full training on A, B, C)
  schema_models = {}
  control_models = {}
  for task in ("A", "B", "C"):
    trainloader, valloader = make_loaders(task)

    # Create the schema model before the control model and fit them in that
    # order, so the random number stream is consumed in a fixed sequence.
    schema_models[task] = vision_transformer.VitAttentionSchema().to(device)
    control_models[task] = vision_transformer.VitControl().to(device)

    fitschema(schema_models[task], trainloader, valloader, tname+task)
    evaluate(schema_models[task], valloader, tname+task+"schema", save_attn=False)
    fitcontrol(control_models[task], trainloader, valloader, tname+task)
    evaluate(control_models[task], valloader, tname+task+"control", save_attn=False)

    del trainloader, valloader

  # FREEZE THE MODELS (everything except each model's policy head)
  freeze_models(list(schema_models.values()) + list(control_models.values()))

  # TRANSFER LEARNING
  for source, target in transfer_pairs:
    trainloader, valloader = make_loaders(target)

    # pop() drops the dictionary's reference so the model can be released
    # as soon as its transfer run is finished.
    schema_model = schema_models.pop(source)
    control_model = control_models.pop(source)

    fitschema(schema_model, trainloader, valloader, tname+source+"schema"+target, policy_only=True)
    evaluate(schema_model, valloader, tname+source+"schema"+target)

    fitcontrol(control_model, trainloader, valloader, tname+source+"control"+target)
    # Evaluate the control model here: the "...control..." result file must
    # hold the control model's accuracy, not the schema model's.
    evaluate(control_model, valloader, tname+source+"control"+target)

    del trainloader, valloader, schema_model, control_model

