# Class token attention task: the discrimination of genuine and artificially shuffled class tokens (Fig 4)

In [None]:
# Remove any existing torch/torchaudio/torchvision/torchtext/torchdata packages,
# reinstall torch, torchvision and torchdata, then install torchrl.
!pip3 uninstall --yes torch torchaudio torchvision torchtext torchdata
!pip3 install torch torchvision torchdata
!pip3 install torchrl

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
import sys
sys.path.append('/content/drive/My Drive/networkattention')
# vit model from https://github.com/facebookresearch/dino/blob/main/README.md
import importlib
import torch
import torch.random
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader
import torch.utils.data as data_utils
import numpy as np
from PIL import Image
import vision_transformer_clspred
import fnmatch
import os
import glob
import shutil
import matplotlib.pyplot as plt
from copy import deepcopy
device = torch.device("cuda")

# IMAGE CLASSIFICATION A, B, AND C: DATA DIRECTORIES AND PREPROCESSING
# One train/val folder pair per classification task.
traindirA = "/content/drive/My Drive/networkattention/data/train/classificationA"
valdirA = "/content/drive/My Drive/networkattention/data/val/classificationA"

traindirB = "/content/drive/My Drive/networkattention/data/train/classificationB"
valdirB = "/content/drive/My Drive/networkattention/data/val/classificationB"

traindirC = "/content/drive/My Drive/networkattention/data/train/classificationC"
valdirC = "/content/drive/My Drive/networkattention/data/val/classificationC"

# Training and validation images get the same preprocessing:
# resize to 256x256, then convert to a tensor.
train_transforms = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)
val_transforms = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)


def schematrain(model, x, y, optimizer):
    """Run one optimisation step on the combined schema objective.

    The model must return three tensors; they are used, in order, as the
    attention prediction, the target it is regressed onto, and the policy
    logits. The loss is ``0.05 * MSE(prediction, target)`` plus the
    binary cross-entropy (with logits) between the policy logits and ``y``.

    Args:
        model: network whose ``forward(x)`` returns the three tensors above.
        x: input batch.
        y: float targets with the same shape as the policy logits.
        optimizer: optimiser stepped once and then cleared.

    Returns:
        The combined loss tensor for this batch.
    """
    attn_prediction, attn_target, logits = model.forward(x)

    prediction_criterion = torch.nn.MSELoss()
    policy_criterion = torch.nn.BCEWithLogitsLoss()

    # The attention-prediction term is down-weighted relative to the policy term.
    prediction_term = 0.05 * prediction_criterion(attn_prediction, attn_target)
    policy_term = policy_criterion(logits, y)
    combined = prediction_term + policy_term

    combined.backward()
    optimizer.step()
    # Gradients are cleared after the step so the next call starts clean.
    optimizer.zero_grad()
    return combined

def schematrain_policy(model, x, y, optimizer):
    """Run one optimisation step on the policy loss of a schema model.

    The model must return three tensors; only the third (policy logits)
    contributes to the loss, which is binary cross-entropy with logits
    against ``y``.

    Args:
        model: network whose ``forward(x)`` returns three tensors.
        x: input batch.
        y: float targets with the same shape as the policy logits.
        optimizer: optimiser stepped once and then cleared.

    Returns:
        The policy loss tensor for this batch.
    """
    _, _, logits = model.forward(x)
    batch_loss = torch.nn.BCEWithLogitsLoss()(logits, y)
    batch_loss.backward()
    optimizer.step()
    # Gradients are cleared after the step so the next call starts clean.
    optimizer.zero_grad()
    return batch_loss

def controltrain(model, x, y, optimizer):
    """Run one optimisation step on the policy loss of a control model.

    The model must return two tensors; only the second (policy logits)
    contributes to the loss, which is binary cross-entropy with logits
    against ``y``.

    Args:
        model: network whose ``forward(x)`` returns two tensors.
        x: input batch.
        y: float targets with the same shape as the policy logits.
        optimizer: optimiser stepped once and then cleared.

    Returns:
        The policy loss tensor for this batch.
    """
    _, logits = model.forward(x)
    batch_loss = torch.nn.BCEWithLogitsLoss()(logits, y)
    batch_loss.backward()
    optimizer.step()
    # Gradients are cleared after the step so the next call starts clean.
    optimizer.zero_grad()
    return batch_loss

def fitschema(model, trainloader, valloader, name="", n_epochs=20, policy_only=False):
    """Train a schema model and restore the weights with the best validation loss.

    Each epoch runs one pass over ``trainloader`` (using ``schematrain`` or,
    when ``policy_only`` is true, ``schematrain_policy``), then measures the
    mean policy BCE-with-logits loss over ``valloader``. After the last epoch
    the model is reset to the weights of the epoch with the lowest validation
    loss (ties go to the later epoch), and the per-epoch training losses are
    written, one per line, to ``losscurves/<name>schema.txt``.

    Args:
        model: network whose forward pass returns three tensors, the last of
            which is the policy logits.
        trainloader: iterable of ``(x, y)`` training batches.
        valloader: iterable of ``(x, y)`` validation batches.
        name: prefix for the loss-curve file name.
        n_epochs: number of training epochs.
        policy_only: if true, optimise the policy loss only.
    """
    bce = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    train_step = schematrain_policy if policy_only else schematrain

    epoch_train_losses = []
    best_val_loss = float("inf")
    best_model_wts = None

    for epoch in range(n_epochs):
        model.train()
        epoch_loss = 0
        for i, (x_batch, y_batch) in enumerate(trainloader):
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            y_batch = y_batch.unsqueeze(1).float()  # match the (batch, 1) logits
            batch_loss = train_step(model, x_batch, y_batch, optimizer).item()
            epoch_loss += batch_loss/len(trainloader)
            if epoch == 0:
                # Per-batch progress is only printed during the first epoch.
                print('{}: {} / {}: {}'.format(i, batch_loss, len(trainloader), epoch_loss))
        epoch_train_losses.append(epoch_loss)
        print('\nEpoch : {}, train loss : {}'.format(epoch+1,epoch_loss))

        model.eval()
        cum_loss = 0
        with torch.no_grad():
            for x_batch, y_batch in valloader:
                x_batch = x_batch.to(device)
                y_batch = y_batch.unsqueeze(1).float().to(device)
                _, _, policy = model(x_batch)
                cum_loss += bce(policy, y_batch).item()/len(valloader)
        print('Epoch : {}, val loss : {}'.format(epoch+1,cum_loss))

        if cum_loss <= best_val_loss:
            best_val_loss = cum_loss
            # state_dict() returns references to the live parameters, so it
            # must be copied; otherwise later epochs overwrite the "best"
            # checkpoint and restoring it does nothing.
            best_model_wts = deepcopy(model.state_dict())

    # Nothing to restore if no epoch ran (or every validation loss was NaN).
    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)

    with open("/content/drive/My Drive/networkattention/losscurves/"+name+"schema.txt","w") as file:
        for item in epoch_train_losses:
            file.write(str(item)+"\n")

def fitcontrol(model, trainloader, valloader, name="", n_epochs=20):
    """Train a control model and restore the weights with the best validation loss.

    Each epoch runs one pass over ``trainloader`` using ``controltrain``, then
    measures the mean policy BCE-with-logits loss over ``valloader``. After
    the last epoch the model is reset to the weights of the epoch with the
    lowest validation loss (ties go to the later epoch), and the per-epoch
    training losses are written, one per line, to
    ``losscurves/<name>control.txt``.

    Args:
        model: network whose forward pass returns two tensors, the second of
            which is the policy logits.
        trainloader: iterable of ``(x, y)`` training batches.
        valloader: iterable of ``(x, y)`` validation batches.
        name: prefix for the loss-curve file name.
        n_epochs: number of training epochs.
    """
    bce = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    epoch_train_losses = []
    best_val_loss = float("inf")
    best_model_wts = None

    for epoch in range(n_epochs):
        model.train()
        epoch_loss = 0
        for i, (x_batch, y_batch) in enumerate(trainloader):
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            y_batch = y_batch.unsqueeze(1).float()  # match the (batch, 1) logits
            batch_loss = controltrain(model, x_batch, y_batch, optimizer).item()
            epoch_loss += batch_loss/len(trainloader)
            if epoch == 0:
                # Per-batch progress is only printed during the first epoch.
                print('{}: {} / {}: {}'.format(i, batch_loss, len(trainloader), epoch_loss))
        epoch_train_losses.append(epoch_loss)
        print('\nEpoch : {}, train loss : {}'.format(epoch+1,epoch_loss))

        model.eval()
        cum_loss = 0
        with torch.no_grad():
            for x_batch, y_batch in valloader:
                x_batch = x_batch.to(device)
                y_batch = y_batch.unsqueeze(1).float().to(device)
                _, policy = model(x_batch)
                cum_loss += bce(policy, y_batch).item()/len(valloader)
        print('Epoch : {}, val loss : {}'.format(epoch+1,cum_loss))

        if cum_loss <= best_val_loss:
            best_val_loss = cum_loss
            # state_dict() returns references to the live parameters, so it
            # must be copied; otherwise later epochs overwrite the "best"
            # checkpoint and restoring it does nothing.
            best_model_wts = deepcopy(model.state_dict())

    # Nothing to restore if no epoch ran (or every validation loss was NaN).
    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)

    with open("/content/drive/My Drive/networkattention/losscurves/"+name+"control.txt","w") as file:
        for item in epoch_train_losses:
            file.write(str(item)+"\n")

def evaluate(model, valloader, name="", save_attn=False):
    """Score a model on ``valloader`` and write the results to Drive.

    The last element of the model's output is treated as the policy logits;
    predictions are ``round(sigmoid(logits))``. Three text files are written:
    the overall accuracy (``accuracies/acc<name>.txt``), the list of
    predictions and the list of labels
    (``classifications/<name>_classifications.txt`` / ``_labels.txt``).

    Accuracy is computed over all samples (``1 - total_errors / total_samples``)
    rather than as an unweighted mean of per-batch accuracies, so a shorter
    final batch is not over-weighted.

    Args:
        model: network to evaluate; must provide ``get_last_selfattention``
            when ``save_attn`` is true.
        valloader: iterable of ``(x, y)`` batches.
        name: tag used in every output file name.
        save_attn: if true, also save the first three channels along dim 1 of
            ``model.get_last_selfattention(x)`` for every batch to
            ``data/attentions/<name>attn.pt``.
    """
    classifications = []
    labels = []
    # Seeded with an empty tensor so the saved result keeps this shape even
    # when the loader yields no batches.
    attn_chunks = [torch.empty((0,3,257,257)).to(device)]
    sigmoid = torch.nn.Sigmoid()
    n_errors = 0.0
    n_samples = 0

    model.eval()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for x_batch, y_batch in valloader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            y_batch = y_batch.unsqueeze(1).float()
            outputs = model.forward(x_batch)
            policy = torch.round(sigmoid(outputs[-1]))
            n_errors += torch.sum(abs(policy - y_batch)).item()
            n_samples += len(y_batch)
            if save_attn:
                attn = model.get_last_selfattention(x_batch)
                attn_chunks.append(attn[:,0:3,:,:])
            classifications.extend(pol.item() for pol in policy)
            labels.extend(yb.item() for yb in y_batch)

    total_acc = 1 - n_errors/n_samples if n_samples else 0

    with open("/content/drive/My Drive/networkattention/accuracies/acc"+name+".txt","w") as file:
        file.write(str(total_acc))

    with open("/content/drive/My Drive/networkattention/classifications/"+name+"_classifications.txt","w") as file:
        file.write(str(classifications))

    with open("/content/drive/My Drive/networkattention/classifications/"+name+"_labels.txt","w") as file:
        file.write(str(labels))

    if save_attn:
        attn_outputs = torch.cat(attn_chunks, 0)
        torch.save(attn_outputs, "/content/drive/My Drive/networkattention/data/attentions/"+name+"attn.pt")

def freeze_models(models):
    """Freeze every parameter of each model except those of its ``policy`` submodule.

    Args:
        models: iterable of modules, each exposing a ``policy`` submodule.
    """
    for net in models:
        # Switch everything off first, then re-enable the policy head.
        net.requires_grad_(False)
        net.policy.requires_grad_(True)


Mounted at /content/drive


In [None]:
# Scratch check: push one saved attention map through an untrained schema model.
schem = vision_transformer_clspred.VitAttentionSchema().to(device)
cont = vision_transformer_clspred.VitControl()

attn = torch.load("/content/drive/My Drive/networkattention/data/attentions/modelAschemaattn.pt")

# Take index 0 along dim 2 (presumably the class-token row - TODO confirm),
# stretch it to 65536 values by nearest-neighbour interpolation and fold it
# into a 3x256x256 image-shaped tensor so the image models can consume it.
clsattn = torch.nn.functional.interpolate(attn[:,:,0,:], size=(65536), mode="nearest")
# Use the file's own sample count instead of a hard-coded 730 so the cell
# also works for attention files of a different size.
cls = clsattn.reshape(attn.shape[0],3,256,256)
print(cls.shape)
schem(cls[0,...].unsqueeze(0))

In [None]:
tname = "NEWCLStrial0"
torch.random.seed()


def _fit_and_save_attention(train_dir, val_dir, run_name):
    """Train a schema/control pair on one image task and save their attention.

    Builds the loaders, creates both models, fits the schema model then the
    control model, and evaluates both with attention saving enabled. The
    datasets and loaders are local, so they are released on return.
    """
    train_data = datasets.ImageFolder(train_dir, transform=train_transforms)
    val_data = datasets.ImageFolder(val_dir, transform=val_transforms)

    trainloader = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=8)
    valloader = torch.utils.data.DataLoader(val_data, shuffle=True, batch_size=8)

    schema_model = vision_transformer_clspred.VitAttentionSchema().to(device)
    control_model = vision_transformer_clspred.VitControl().to(device)

    fitschema(schema_model, trainloader, valloader, run_name)
    fitcontrol(control_model, trainloader, valloader, run_name)

    evaluate(schema_model, valloader, run_name + "schema", save_attn=True)
    evaluate(control_model, valloader, run_name + "control", save_attn=True)
    return schema_model, control_model


# B
modelBschema, modelBcontrol = _fit_and_save_attention(traindirB, valdirB, tname + "B")

# C
modelCschema, modelCcontrol = _fit_and_save_attention(traindirC, valdirC, tname + "C")

0: 0.6646795272827148 / 239: 0.002781085888212196
1: 0.6534203886985779 / 239: 0.005515062409963568
2: 0.8286062479019165 / 239: 0.008982034158507152
3: 0.6653600931167603 / 239: 0.011765967602510332
4: 0.7489864826202393 / 239: 0.01489980225782514
5: 0.6522353887557983 / 239: 0.017628820620820114
6: 0.6897300481796265 / 239: 0.020514720403998467
7: 0.7265724539756775 / 239: 0.02355477251268331
8: 0.6764973998069763 / 239: 0.026385305566268984
9: 0.6736147999763489 / 239: 0.029203777532697223
10: 0.6753988862037659 / 239: 0.03202971429505608
11: 0.6847952008247375 / 239: 0.03489496618135205
12: 0.6776763200759888 / 239: 0.037730431955728574
13: 0.6642135381698608 / 239: 0.040509568098698705
14: 0.7013245820999146 / 239: 0.043443980576104206
15: 0.679368257522583 / 239: 0.046286525586658946
16: 0.7302337884902954 / 239: 0.04934189708661834
17: 0.6957350969314575 / 239: 0.052252922596791804
18: 0.5531455874443054 / 239: 0.05456733928065919
19: 0.8296688199043274 / 239: 0.0580387569371626

In [None]:
# Index into `seeds`; advanced after every attention-discrimination phase.
seedn = 7
seeds = [38333, 79984, 66390, 18079, 15244, 36378, 65024, 16278, 96089, 80817,
         15820, 84086, 81869, 42235, 78231, 13501, 62904, 81499, 70093, 45182,
         36201, 31709, 37794, 82643, 49473, 23698, 25148, 82890, 75009, 40721,
         39663, 45414, 51407, 17793, 66530, 19383, 86574, 59920, 13187, 82918,
         64587, 84040, 14681, 31941, 68987, 82403, 23492, 54427, 29316, 38279,
         75782, 12617, 36385, 71490, 12924, 51894, 29646, 15556, 52387, 47171,
         36631, 64924, 32661, 60206, 61079, 38763, 47185, 51165, 37078, 91493,
         15507, 94325, 38695, 18167, 55667, 23816, 55812, 55002, 84690, 91869,
         70293, 83810, 78368, 18832, 89594, 86696, 76356, 86349, 35784, 54066,
         23793, 93586, 78475, 52226, 92427, 23244, 35298, 61167, 93357, 11480,
         32909, 91040, 25266, 43908, 75798, 44285, 50841, 25611, 48329, 40202,
         39898, 21002, 73790, 94005, 71650, 47142, 22625, 12788, 99651, 62684,
         21942, 52884, 91096, 64224, 22993, 59695, 28217, 42843, 53415, 20168,
         71887, 75207, 81045, 55057, 96637, 17166, 61124, 93537, 19632, 47096,
         77510, 26684, 64662, 81874, 76832, 66655, 75951, 26273, 26645, 91141]

# NOTE(review): the attention files read below are named "CLStrial0...", while
# the trial-0 cell in this notebook saves "NEWCLStrial0..." - confirm that the
# older files are the intended inputs.
_ATTN_DIR = "/content/drive/My Drive/networkattention/data/attentions/"


def _train_image_models(train_dir, val_dir, run_name):
    """Train and evaluate a fresh schema/control pair on one image task.

    Returns ``(schema_model, control_model)``. Datasets and loaders are local
    and therefore released on return.
    """
    train_data = datasets.ImageFolder(train_dir, transform=train_transforms)
    val_data = datasets.ImageFolder(val_dir, transform=val_transforms)

    trainloader = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=8)
    valloader = torch.utils.data.DataLoader(val_data, shuffle=True, batch_size=8)

    schema_model = vision_transformer_clspred.VitAttentionSchema().to(device)
    control_model = vision_transformer_clspred.VitControl().to(device)

    fitschema(schema_model, trainloader, valloader, run_name)
    fitcontrol(control_model, trainloader, valloader, run_name)

    evaluate(schema_model, valloader, run_name + "schema", save_attn=False)
    evaluate(control_model, valloader, run_name + "control", save_attn=False)
    return schema_model, control_model


def _attention_loaders(attn_file, batch_size):
    """Build train/val loaders of genuine (label 1) vs shuffled (label 0) attention.

    Index 0 along dim 2 of the saved attention is stretched to 65536 values by
    nearest-neighbour interpolation and folded into 3x256x256 tensors. The
    shuffled copies apply one random permutation along dim 2 (the first of the
    two 256-sized axes), shared by all samples. The combined set is split
    90% / 10% at random.
    """
    attn = torch.load(attn_file)
    n_samples = attn.shape[0]  # taken from the file instead of a hard-coded 730
    attn = torch.nn.functional.interpolate(attn[:, :, 0, :], size=(65536), mode="nearest")
    attn = attn.reshape(n_samples, 3, 256, 256)

    indices = torch.randperm(attn.shape[-1])
    # Advanced indexing already returns a copy, so no explicit clone is needed.
    false_attn = attn[:, :, indices]

    dataset = TensorDataset(torch.cat((attn, false_attn), 0),
                            torch.cat((torch.ones(n_samples,), torch.zeros(n_samples,)), 0))
    train_split, val_split = torch.utils.data.random_split(dataset, [0.9, 0.1])
    return (DataLoader(train_split, batch_size, shuffle=True),
            DataLoader(val_split, batch_size, shuffle=True))


def _restore_frozen_models(schema_wts, control_wts):
    """Recreate a schema/control pair from saved weights, frozen except the policy head."""
    schema_model = vision_transformer_clspred.VitAttentionSchema().to(device)
    control_model = vision_transformer_clspred.VitControl().to(device)
    schema_model.load_state_dict(schema_wts)
    control_model.load_state_dict(control_wts)
    freeze_models([schema_model, control_model])
    return schema_model, control_model


for trial in range(1, 5):
    torch.manual_seed(seeds[seedn])
    tname = "NEWCLStrial" + str(trial)

    # FIT / EVALUATE MODELS: IMAGE CLASSIFICATION
    modelAschema, modelAcontrol = _train_image_models(traindirA, valdirA, tname + "A")
    modelBschema, modelBcontrol = _train_image_models(traindirB, valdirB, tname + "B")
    modelCschema, modelCcontrol = _train_image_models(traindirC, valdirC, tname + "C")

    # FREEZE THE MODELS
    freeze_models([modelAschema, modelAcontrol, modelBschema, modelBcontrol, modelCschema, modelCcontrol])

    # Snapshot the image-trained weights before any fine-tuning on attention,
    # so the CONTROL phase below restarts from the same point.
    modelAschema_wts = deepcopy(modelAschema.state_dict())
    modelAcontrol_wts = deepcopy(modelAcontrol.state_dict())
    modelBschema_wts = deepcopy(modelBschema.state_dict())
    modelBcontrol_wts = deepcopy(modelBcontrol.state_dict())
    modelCschema_wts = deepcopy(modelCschema.state_dict())
    modelCcontrol_wts = deepcopy(modelCcontrol.state_dict())

    # FIT / EVALUATE MODELS: CLS ATTENTION CLASSIFICATION
    # The model pair for each attention source was trained on a different
    # task: A's attention goes to the C models, B's to A, C's to B.
    batch_size = 8

    # SCHEMA: attention saved from the schema models
    attn_train, attn_val = _attention_loaders(_ATTN_DIR + "CLStrial0Aschemaattn.pt", batch_size)
    # This phase fits and evaluates the control model first; the order is kept
    # because it affects the random-number stream.
    fitcontrol(modelCcontrol, attn_train, attn_val, tname + "CAcontrol_schemaattn", n_epochs=800)
    evaluate(modelCcontrol, attn_val, tname + "CAcontrol_schemaattn", save_attn=False)
    fitschema(modelCschema, attn_train, attn_val, tname + "CAschema_schemaattn", n_epochs=800, policy_only=True)
    evaluate(modelCschema, attn_val, tname + "CAschema_schemaattn", save_attn=False)
    seedn += 1
    torch.manual_seed(seeds[seedn])
    del attn_train, attn_val  # free the dataset before loading the next one

    attn_train, attn_val = _attention_loaders(_ATTN_DIR + "CLStrial0Bschemaattn.pt", batch_size)
    fitschema(modelAschema, attn_train, attn_val, tname + "ABschema_schemaattn", n_epochs=800, policy_only=True)
    fitcontrol(modelAcontrol, attn_train, attn_val, tname + "ABcontrol_schemaattn", n_epochs=800)
    evaluate(modelAschema, attn_val, tname + "ABschema_schemaattn", save_attn=False)
    evaluate(modelAcontrol, attn_val, tname + "ABcontrol_schemaattn", save_attn=False)
    seedn += 1
    torch.manual_seed(seeds[seedn])
    del attn_train, attn_val

    attn_train, attn_val = _attention_loaders(_ATTN_DIR + "CLStrial0Cschemaattn.pt", batch_size)
    # The trial prefix was missing from these two names, so every trial
    # overwrote the same loss-curve files.
    fitschema(modelBschema, attn_train, attn_val, tname + "BCschema_schemaattn", n_epochs=800, policy_only=True)
    fitcontrol(modelBcontrol, attn_train, attn_val, tname + "BCcontrol_schemaattn", n_epochs=800)
    evaluate(modelBschema, attn_val, tname + "BCschema_schemaattn", save_attn=False)
    evaluate(modelBcontrol, attn_val, tname + "BCcontrol_schemaattn", save_attn=False)
    seedn += 1
    torch.manual_seed(seeds[seedn])
    del attn_train, attn_val

    # CONTROL: attention saved from the control models.
    # Drop the fine-tuned models; each pair is rebuilt from its snapshot.
    del modelAschema, modelAcontrol
    del modelBschema, modelBcontrol
    del modelCschema, modelCcontrol

    # Models are restored before the attention data is built, as in the
    # original statement order (model construction consumes random numbers).
    modelCschema, modelCcontrol = _restore_frozen_models(modelCschema_wts, modelCcontrol_wts)
    attn_train, attn_val = _attention_loaders(_ATTN_DIR + "CLStrial0Acontrolattn.pt", batch_size)
    fitschema(modelCschema, attn_train, attn_val, tname + "CAs_controlattn", n_epochs=800, policy_only=True)
    fitcontrol(modelCcontrol, attn_train, attn_val, tname + "CAc_controlattn", n_epochs=800)
    evaluate(modelCschema, attn_val, tname + "CAschema_controlattn", save_attn=False)
    evaluate(modelCcontrol, attn_val, tname + "CAcontrol_controlattn", save_attn=False)
    seedn += 1
    torch.manual_seed(seeds[seedn])
    del attn_train, attn_val
    del modelCschema, modelCcontrol
    del modelCschema_wts, modelCcontrol_wts

    modelAschema, modelAcontrol = _restore_frozen_models(modelAschema_wts, modelAcontrol_wts)
    attn_train, attn_val = _attention_loaders(_ATTN_DIR + "CLStrial0Bcontrolattn.pt", batch_size)
    fitschema(modelAschema, attn_train, attn_val, tname + "ABs_controlattn", n_epochs=800, policy_only=True)
    fitcontrol(modelAcontrol, attn_train, attn_val, tname + "ABc_controlattn", n_epochs=800)
    evaluate(modelAschema, attn_val, tname + "ABschema_controlattn", save_attn=False)
    evaluate(modelAcontrol, attn_val, tname + "ABcontrol_controlattn", save_attn=False)
    seedn += 1
    torch.manual_seed(seeds[seedn])
    del attn_train, attn_val
    del modelAschema, modelAcontrol
    del modelAschema_wts, modelAcontrol_wts

    modelBschema, modelBcontrol = _restore_frozen_models(modelBschema_wts, modelBcontrol_wts)
    attn_train, attn_val = _attention_loaders(_ATTN_DIR + "CLStrial0Ccontrolattn.pt", batch_size)
    fitschema(modelBschema, attn_train, attn_val, tname + "BCs_controlattn", n_epochs=800, policy_only=True)
    fitcontrol(modelBcontrol, attn_train, attn_val, tname + "BCc_controlattn", n_epochs=800)
    evaluate(modelBschema, attn_val, tname + "BCschema_controlattn", save_attn=False)
    evaluate(modelBcontrol, attn_val, tname + "BCcontrol_controlattn", save_attn=False)
    seedn += 1
    torch.manual_seed(seeds[seedn])
    del attn_train, attn_val
    del modelBschema, modelBcontrol
    del modelBschema_wts, modelBcontrol_wts


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
130: 0.9570909738540649 / 165: 2.6774544858571248
131: 1.9722931385040283 / 165: 2.689407777605634
132: 0.7331352829933167 / 165: 2.6938510217449876
133: 3.8234598636627197 / 165: 2.717023505767186
134: 3.7971339225769043 / 165: 2.7400364386312885
135: 4.894780158996582 / 165: 2.7697017729282374
136: 3.2752065658569336 / 165: 2.7895515096910066
137: 2.9079184532165527 / 165: 2.807175257892319
138: 0.884784460067749 / 165: 2.8125375879533356
139: 0.880483865737915 / 165: 2.817873853806293
140: 4.430624961853027 / 165: 2.8447261263023718
141: 1.7367322444915771 / 165: 2.8552517762689873
142: 3.6735332012176514 / 165: 2.8775156138521245
143: 2.766575813293457 / 165: 2.894282739993297
144: 2.3524439334869385 / 165: 2.9085399759538237
145: 2.461306571960449 / 165: 2.9234569854808568
146: 1.906010627746582 / 165: 2.935008565042957
147: 2.284855604171753 / 165: 2.94885617476521
148: 3.498631715774536 / 165: 2.9700600033456617
14