<a href="https://colab.research.google.com/github/hamish-haggerty/AI-hacking/blob/master/SSL/cancer_validation_ensemble.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# cancer_validation_ensemble

> The purpose of this notebook is to explore whether interspersing some BT (Barlow Twins) pretraining makes an ensemble better. 

In [1]:
#| default_exp cancer_validation_ensemble

Setup: Surely there is a way to get rid of having to put this cell everywhere. hmmm.

Or we can just copy paste / delete this in and out when needed. Either way, getting close to a decent workable workflow.

In [1]:
#| hide

import os
from google.colab import drive

def colab_is_true():
    "Return True if `google.colab` can be imported (i.e. we are running on Colab), else False."
    try:
        from google.colab import drive
    except ModuleNotFoundError:
        #no Colab runtime available
        return False
    else:
        return True

def setup_colab():
    "Mount Google Drive, clone and install `cancer-proj`, install nbdev/quarto, then unzip the data archive."
    drive.mount('/content/drive',force_remount=True)
    os.system('git clone https://github.com/hamish-haggerty/cancer-proj.git')

    #every remaining step runs from inside the cloned repo directory
    os.chdir('cancer-proj')

    shell_steps = (
        'pip install .',
        'pip install -qU nbdev',
        'nbdev_install_quarto',
        #NOTE(review): the original author was unsure whether unzipping from here works - confirm
        'unzip -q "/content/drive/My Drive/archive (1).zip"',
    )
    for step in shell_steps: os.system(step)

if __name__ == "__main__":
    #Detect the environment once; `on_colab` is read again by the data-loading cell.
    on_colab = colab_is_true()
    if on_colab: setup_colab()

Mounted at /content/drive


In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
#| export
from fastai.vision.all import *
from base_rbt.all import *
from cancer_proj.cancer_dataloading import *
from cancer_proj.cancer_metrics import *
from cancer_proj.cancer_maintrain import *

In [4]:
@torch.no_grad()
def predict_model(xval,yval,model,aug_pipelines_test,numavg=3):
    """Score `xval` with `model` using test-time augmentation and report accuracy against `yval`.

    `xval` is pushed through the model in a single call, so it must fit in memory as a whole.
    The model is put in eval mode, its outputs are averaged over `numavg` passes through
    `aug_pipelines_test[0]`, and the argmax over dim 1 is taken as the prediction.

    Returns `(scores, ypred, accuracy)` where `accuracy` is a Python number.
    """
    model.eval()
    n_samples = xval.shape[0]

    #accumulate the outputs of `numavg` augmented forward passes (test time augmentation)
    scores = 0
    for _ in range(numavg):
        scores += model(aug_pipelines_test[0](xval))
    scores *= 1/numavg

    ypred = cast(torch.argmax(scores, dim=1),TensorCategory)
    accuracy = (ypred == yval).sum()/n_samples

    return scores,ypred,accuracy.item()

## Load the data

In [5]:
#| hide

#Pick the data directories for the current environment. The data itself is not stored in the
#cloned `cancer-proj` repo; the chdir calls that stepped out of / back into the repo are
#currently disabled.
train_dir,test_dir = (colab_train_dir,colab_test_dir) if on_colab else (local_train_dir,local_test_dir)

#general hyperparameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
bs = bs_tune = 256
size = 128
bs_val = 174

#build the data dictionary
data_dict = get_fnames_dls_dict(train_dir=train_dir,test_dir=test_dir,device=device,
                                bs_val=bs_val,bs=bs,bs_tune=bs_tune,size=size,n_in=3)

#dataloaders and tensors
dls_train,dls_tune,dls_valid = (data_dict[k] for k in ('dls_train','dls_tune','dls_valid'))
x,y = data_dict['x'],data_dict['y']
xval,yval = data_dict['xval'],data_dict['yval']
xtune,ytune = data_dict['xtune'],data_dict['ytune']
vocab = data_dict['vocab']

#file names, handy for tests that check the data is the same in every session
fnames,fnames_train,fnames_tune,fnames_valid,fnames_test = (
    data_dict[k] for k in ('fnames','fnames_train','fnames_tune','fnames_valid','fnames_test'))

test_eq(x.shape,xtune.shape)

## Load aug pipelines here

In [6]:
#| hide

#Build the three augmentation pipelines (self-supervised, tuning, test) in one call.
aug_dict = create_aug_pipelines(size=size,device=device,Augs=BYOL_Augs,TUNE_Augs=TUNE_Augs,Val_Augs=Val_Augs)
aug_pipelines,aug_pipelines_tune,aug_pipelines_test = (
    aug_dict[k] for k in ('aug_pipelines','aug_pipelines_tune','aug_pipelines_test'))

## Optionally, display:

In [9]:
#| hide
#show_bt_batch(dls=dls_train,aug=aug_pipelines,n_in=3)

In [10]:
#| hide

#show_linear_batch(dls=dls_tune,n_in=3,aug=aug_pipelines_tune,n=2,print_augs=True)

In [7]:
#| export

@patch
def lf(self:BarlowTwins, pred,*yb):
    "Loss for the `BarlowTwins` callback: delegate to `lf_bt` with this callback's `I` and `lmb` (`yb` is ignored)."
    return lf_bt(pred,I=self.I,lmb=self.lmb)

Need to run a few exploratory experiments. Based on the results, next is to run some systematic experiments, probably with W and B... Or final results...

In [8]:
#| export

@patch
@delegates(Learner.fit_one_cycle)
def encoder_fine_tune(self:Learner, epochs, base_lr=2e-3, freeze_epochs=1, lr_mult=100,
              pct_start=0.3, div=5.0, **kwargs):
    """Fine tuner to use with bt initial weights.

    Freezes the learner, fits `freeze_epochs` epochs at `base_lr`, then fits `epochs` more
    epochs at half that rate while still frozen, and only unfreezes once training is done.
    `lr_mult` is accepted but not used.
    """
    #stage 1: frozen fit at the full learning rate
    self.freeze()
    self.fit_one_cycle(freeze_epochs, slice(base_lr), pct_start=0.99, **kwargs)

    #stage 2: still frozen (no unfreeze here on purpose), same rate for both ends of the slice
    halved_lr = base_lr/2
    self.fit_one_cycle(epochs, slice(halved_lr, halved_lr), pct_start=pct_start, div=div, **kwargs)

    #leave the learner unfrozen for whatever comes next
    self.unfreeze()

## Exploratory experiment: BT initial weights, with a small amount of pretraining. First, let's try updating all of the weights (i.e. the resnet gets updated with BT pretraining). Remember, we need to freeze the pretrained resnet first, and align the encoder-head + projector head.

# We need to edit several of our base functions: Since we have to align the head of the encoder with the projector, we need to edit `create_model`, and define a new bt_splitter: i.e. the splitter needs to freeze the pretrained resnet, and leave the new head_encoder + projector unfrozen.

In [9]:
#| export

class HeadEncoder(nn.Module):
    "Basic nonlinear head on top of a resnet encoder: Linear(2048,2048) -> BatchNorm1d -> ReLU."
    def __init__(self,resnet_encoder,device='cuda'):
        super().__init__()
        self.resnet_encoder = resnet_encoder

        #BatchNorm1d is left on its default eps/momentum/affine/track_running_stats settings
        head_layers = (nn.Linear(2048,2048), nn.BatchNorm1d(2048), nn.ReLU(inplace=True))
        self.head_encoder = sequential(*head_layers)

        #move the whole module (encoder + head) to the requested device
        self.device = torch.device(device)
        self.to(self.device)

    def forward(self,x):
        "Run `x` through the resnet encoder and then the head."
        return self.head_encoder(self.resnet_encoder(x))

def create_model(which_model,device,ps=8192,n_in=3):
    """Build a Barlow Twins model whose encoder is a resnet50 wrapped in `HeadEncoder`.

    which_model: 'bt_pretrain' (Barlow Twins weights from torch.hub), 'no_pretrain'
                 (`resnet50()` with no weights argument) or 'supervised_pretrain'
                 (`resnet50(weights='IMAGENET1K_V2')`).
    device:      string or `torch.device`; the models are moved only for CUDA devices.
    ps:          passed as both `hidden_size` and `projection_size` of the projector.
    n_in:        accepted for interface compatibility; not used here.

    Returns `(model, encoder)`. Raises `ValueError` for an unknown `which_model`.
    """
    print('inside create_model')

    if which_model == 'bt_pretrain':
        model = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')
    elif which_model == 'no_pretrain':
        model = resnet50()
    elif which_model == 'supervised_pretrain':
        model = resnet50(weights='IMAGENET1K_V2')
    else:
        #fail loudly instead of hitting an unbound `model` further down
        raise ValueError(f"Unknown which_model={which_model!r}: expected 'bt_pretrain', 'no_pretrain' or 'supervised_pretrain'")

    encoder = get_resnet_encoder(model)
    encoder = HeadEncoder(encoder,device='cpu') #assemble on cpu, move afterwards

    model = create_barlow_twins_model(encoder, hidden_size=ps,projection_size=ps,nlayers=3)

    #str() so that torch.device('cuda') and strings like 'cuda:0' are handled too
    if str(device).startswith('cuda'):
        model.to(device)
        encoder.to(device)

    return model,encoder

#Instantiate a model straight away so the splitter below can be sanity-checked in this cell.
bt_model,encoder = create_model(which_model='bt_pretrain',ps=8192,device=device)

def my_splitter_bt(m):
    "Split a BT model into two parameter groups: the resnet body, and head_encoder + projector."
    body = sequential(*m.encoder.resnet_encoder)
    head = sequential(m.encoder.head_encoder,m.projector)
    return L(body,head).map(params)

test_eq(len(my_splitter_bt(bt_model)),2)

inside create_model


Downloading: "https://github.com/facebookresearch/barlowtwins/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/barlowtwins/ep1000_bs2048_lrw0.2_lrb0.0048_lambd0.0051/resnet50.pth" to /root/.cache/torch/hub/checkpoints/resnet50.pth


  0%|          | 0.00/90.0M [00:00<?, ?B/s]

In [14]:
# #Verify that splitter freezes expected part of model:

# #test : manual. BT

#Freeze via the BT splitter and inspect the "Trainable" column of the summary.
learn = Learner(
    dls_train,bt_model,
    splitter=my_splitter_bt,
    cbs=[BarlowTwins(aug_pipelines,n_in=3,lmb=1/8192,print_augs=False)],
)
learn.freeze()
print('resnet should be frozen, encoder head + projector unfrozen')
learn.summary()


resnet should be frozen, encoder head + projector unfrozen


BarlowTwinsModel (Input shape: 256 x 3 x 128 x 128)
Layer (type)         Output Shape         Param #    Trainable 
                     256 x 64 x 64 x 64  
Conv2d                                    9408       False     
BatchNorm2d                               128        True      
ReLU                                                           
____________________________________________________________________________
                     256 x 64 x 32 x 32  
MaxPool2d                                                      
Conv2d                                    4096       False     
BatchNorm2d                               128        True      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
____________________________________________________________________________
                     256 x 256 x 32 x 32 
Conv2d                                    16384      False     
BatchNorm2d                 

## We also need to edit main_train. Now the model is: resnet_encoder -> head_encoder -> projector or linear layer. Also need to edit the splitter function for fine tuning.

All that changes here is the definition of our model in `fine_tune`, and we need a new splitter function.

We tried just patching in a new definition of `fine_tune`, but since the `create_model` definition changed, it broke things. 

Feels like this should be some kind of callback, written extensibly enough that you can just "patch" functions etc. in...

In [10]:
#| export

class main_train:
    """Instantiate and (optionally) train the encoder. Then fine-tune the supervised model.
    Outputs metrics on validation data.

    Calling the instance runs `train_encoder` followed by `fine_tune` and returns the
    metrics object produced by `fine_tune`.
    """

    def __init__(self,
                 dls_train, #used for training BT (if pretrain=True)
                 dls_tune , #used for tuning
                 dls_valid, #used to compute metrics / evaluate results. 
                 xval, #currently `predict_model` assumes this is entire validation / test data
                 yval,
                 aug_pipelines, #the aug pipeline for self-supervised learning
                 aug_pipelines_tune, #the aug pipeline for supervised learning
                 aug_pipelines_test, #test (or valid) time augmentations 
                 initial_weights, #Which initial weights to use
                 pretrain, #Whether to fit BT
                 num_epochs, #number of BT fit epochs
                 numfit, #number of tune_fit epochs
                 freeze_num_epochs, #How many epochs to freeze body for when training BT
                 freeze_numfit, #How many epochs to freeze body for when fine tuning
                 ps=8192, #projection size
                 n_in=3, #color channels
                 indim=2048, #dimension output of encoder (2048 for resnet50)
                 outdim=9, #number of classes
                 print_report=False, #F1 metrics etc
                 print_plot=False, #ROC curve
                 ):
        store_attr()
        self.vocab = self.dls_valid.vocab
        #accessed through `self` so the lookup does not depend on the global name `main_train`
        self.device = self._default_device()

        #Soon we might want to save some models here:
        #if self.model_type == 'res_proj': test_eq(self.fit_policy,'resnet_fine_tune') #I THINK this is only viable option?
        #self.encoder_path = f'/content/drive/My Drive/models/baselineencoder_initial_weights={self.initial_weights}_pretrain={self.pretrain}.pth'
        #self.tuned_model_path = f'/content/drive/My Drive/models/baselinefinetuned_initial_weights={self.initial_weights}_pretrain={self.pretrain}.pth'

    @staticmethod
    def _default_device():
        "Return 'cuda' when a CUDA device is available, otherwise 'cpu'."
        #`is_available` must be *called*: the bare function object is always truthy
        return 'cuda' if torch.cuda.is_available() else 'cpu'

    @staticmethod
    def fit(learn,fit_type,epochs,freeze_epochs,initial_weights):
        """Dispatch to the fitting routine selected by `fit_type`.

        'encoder_fine_tune' calls `learn.encoder_fine_tune` (i.e. barlow twins) and
        'fine_tune' calls `learn.linear_fine_tune`; anything else raises `ValueError`.
        `initial_weights` is not used here; it is available so a patched-in variant can
        e.g. treat supervised_pretrain differently from bt_pretrain.
        """
        if fit_type == 'encoder_fine_tune':
            learn.encoder_fine_tune(epochs,freeze_epochs=freeze_epochs)
        elif fit_type == 'fine_tune':
            learn.linear_fine_tune(epochs,freeze_epochs=freeze_epochs)
        else:
            raise ValueError('Fit policy not of expected form')

    def train_encoder(self):
        """Create the encoder and, if `pretrain` is set, train it with the BT algorithm.

        If an encoder already exists on the instance it is reused with a new projector;
        otherwise a model is created from `initial_weights`. The resulting encoder is
        stored on `self.encoder`.
        """
        if hasattr(self,'encoder'):
            #get existing encoder and plonk on new projector
            encoder = self.encoder
            encoder.cpu() #NOTE(review): presumably so encoder and new projector start on the same device - confirm
            bt_model = create_barlow_twins_model(encoder, hidden_size=self.ps,projection_size=self.ps,nlayers=3)
            bt_model.to(self.device)
        else:
            bt_model,encoder = create_model(which_model=self.initial_weights,ps=self.ps,device=self.device)

        if self.pretrain:
            bt_cb = BarlowTwins(self.aug_pipelines,n_in=self.n_in,lmb=1/self.ps,print_augs=False)
            learn = Learner(self.dls_train,bt_model,splitter=my_splitter_bt,cbs=[bt_cb])
            main_train.fit(learn,fit_type='encoder_fine_tune',
                           epochs=self.num_epochs,freeze_epochs=self.freeze_num_epochs,
                           initial_weights=self.initial_weights
                          )

        self.encoder = bt_model.encoder

    def fine_tune(self):
        """Fine tune encoder + linear layer in supervised fashion and compute metrics.

        Creates the encoder from `initial_weights` if none exists yet, attaches an
        `nn.Linear(indim, outdim)` layer, fits on `dls_tune`, then evaluates on
        `xval`/`yval` with `predict_model`. Returns the metrics from
        `classification_report_wrapper` extended with the keys 'acc', 'auc_dict',
        'scores', 'preds', 'xval' and 'yval'.
        """
        if not hasattr(self,'encoder'):
            _,self.encoder = create_model(which_model=self.initial_weights,ps=self.ps,device=self.device)

        #use the configured dimensions rather than hard-coded 2048/9 (these are the defaults)
        model = sequential(self.encoder,nn.Linear(self.indim,self.outdim))

        tune_cb = LinearBt(aug_pipelines=self.aug_pipelines_tune,n_in=self.n_in)
        learn = Learner(self.dls_tune,model,splitter=my_splitter,cbs=[tune_cb],wd=0.0)

        main_train.fit(learn,fit_type='fine_tune',
                       epochs=self.numfit,freeze_epochs=self.freeze_numfit,
                       initial_weights=self.initial_weights
                      ) #fine tuning (don't confuse this with fit policy!)

        scores,preds,acc = predict_model(self.xval,self.yval,model=model,aug_pipelines_test=self.aug_pipelines_test,numavg=3)

        #metrics will have f1 score etc; auc etc is added below
        metrics = classification_report_wrapper(preds, self.yval, self.vocab, print_report=self.print_report)
        auc_dict = plot_roc(self.yval,preds,self.vocab,print_plot=self.print_plot)

        metrics['acc'] = acc
        metrics['auc_dict'] = auc_dict
        metrics['scores'] = scores
        metrics['preds'] = preds
        metrics['xval'] = self.xval
        metrics['yval'] = self.yval

        return metrics

    def __call__(self):
        "Train (or extract) the encoder, fine tune, and return the metrics."
        self.train_encoder()
        metrics = self.fine_tune()

        return metrics



We need to define the splitter function for the fine_tune part of main differently as well:

In [11]:
def my_splitter(m):
    "Split `sequential(encoder, linear)` into two parameter groups: the resnet body, and head_encoder + linear."
    print('inside new my_splitter')
    wrapped_encoder,classifier = m[0],m[1]
    body = sequential(*wrapped_encoder.resnet_encoder)
    head = sequential(wrapped_encoder.head_encoder,classifier)
    return L(body,head).map(params)

In [31]:
# # #Verify that splitter freezes expected part of model, from linear point of view:

#Fresh BT model plus the supervised model built on the same encoder
bt_model,encoder = create_model(which_model='bt_pretrain',ps=8192,device=device)
model = sequential(encoder,nn.Linear(2048,9))

#both splitters must yield exactly two parameter groups
test_eq(len(my_splitter(model)),2)
test_eq(len(my_splitter_bt(bt_model)),2)

learn = Learner(
    dls_tune,model,
    splitter=my_splitter,
    cbs=[LinearBt(aug_pipelines=aug_pipelines_tune,n_in=3)],
    wd=0.0,
)
learn.freeze()
print('resnet should be frozen, encoder_head + linear layer unfrozen')
learn.summary()


inside create_model


Using cache found in /root/.cache/torch/hub/facebookresearch_barlowtwins_main


inside new my_splitter
inside new my_splitter
resnet should be frozen, encoder_head + linear layer unfrozen


Sequential (Input shape: 256 x 3 x 128 x 128)
Layer (type)         Output Shape         Param #    Trainable 
                     256 x 64 x 64 x 64  
Conv2d                                    9408       False     
BatchNorm2d                               128        True      
ReLU                                                           
____________________________________________________________________________
                     256 x 64 x 32 x 32  
MaxPool2d                                                      
Conv2d                                    4096       False     
BatchNorm2d                               128        True      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
____________________________________________________________________________
                     256 x 256 x 32 x 32 
Conv2d                                    16384      False     
BatchNorm2d                       

## First we need to verify that the head still gives good performance:

In [18]:
#Non default inputs
# initial_weights = 'supervised_pretrain'
# pretrain=False
# numfit=50
# num_epochs='na'
# freeze_num_epochs = 'na'
# freeze_numfit=3

# main = main_train(dls_train=dls_train,dls_tune=dls_tune,dls_valid=dls_valid, xval=xval, yval=yval,
#         aug_pipelines=aug_pipelines, aug_pipelines_tune=aug_pipelines_tune, aug_pipelines_test=aug_pipelines_test, 
#         initial_weights=initial_weights,pretrain=pretrain,
#         num_epochs=num_epochs,numfit=numfit,freeze_num_epochs=freeze_num_epochs,freeze_numfit=freeze_numfit,
#         print_report=True,
#                  )

# metrics = main()

Ok, great.

## Check for BT as well:

This hasn't been working. Hypothesis: need to freeze backbone for longer.

In [19]:
#Settings that differ from the defaults
initial_weights,pretrain = 'bt_pretrain',False
numfit,freeze_numfit = 50,6
#placeholders: main_train only reads these inside its `if self.pretrain` branch
num_epochs = freeze_num_epochs = 'na'

main = main_train(
    dls_train=dls_train,dls_tune=dls_tune,dls_valid=dls_valid,
    xval=xval,yval=yval,
    aug_pipelines=aug_pipelines,aug_pipelines_tune=aug_pipelines_tune,aug_pipelines_test=aug_pipelines_test,
    initial_weights=initial_weights,pretrain=pretrain,
    num_epochs=num_epochs,numfit=numfit,
    freeze_num_epochs=freeze_num_epochs,freeze_numfit=freeze_numfit,
    print_report=True,
)

metrics = main()

inside create_model


Using cache found in /root/.cache/torch/hub/facebookresearch_barlowtwins_main


inside new my_splitter


epoch,train_loss,valid_loss,time
0,2.226256,,00:06
1,2.07304,,00:06
2,1.875264,,00:06
3,1.674538,,00:06
4,1.505282,,00:06
5,1.353871,,00:06


  warn("Your generator is empty.")


epoch,train_loss,valid_loss,time
0,0.597909,,00:06
1,0.593243,,00:06
2,0.556044,,00:06
3,0.509185,,00:06
4,0.493918,,00:06
5,0.461221,,00:06
6,0.438027,,00:06
7,0.417411,,00:06
8,0.402489,,00:06
9,0.381574,,00:06


                            precision    recall  f1-score   support

         actinic keratosis       0.59      0.65      0.62        20
      basal cell carcinoma       0.61      0.70      0.65        20
            dermatofibroma       0.80      0.84      0.82        19
                  melanoma       0.33      0.25      0.29        20
                     nevus       0.38      0.50      0.43        20
pigmented benign keratosis       0.62      0.40      0.48        20
      seborrheic keratosis       0.47      0.47      0.47        15
   squamous cell carcinoma       0.53      0.50      0.51        20
           vascular lesion       0.86      0.90      0.88        20

                  accuracy                           0.58       174
                 macro avg       0.58      0.58      0.57       174
              weighted avg       0.58      0.58      0.57       174



Try freezing the resnet for longer

In [37]:
#Settings that differ from the defaults (longer frozen phase than the previous run)
initial_weights,pretrain = 'bt_pretrain',False
numfit,freeze_numfit = 30,20
#placeholders: main_train only reads these inside its `if self.pretrain` branch
num_epochs = freeze_num_epochs = 'na'

main = main_train(
    dls_train=dls_train,dls_tune=dls_tune,dls_valid=dls_valid,
    xval=xval,yval=yval,
    aug_pipelines=aug_pipelines,aug_pipelines_tune=aug_pipelines_tune,aug_pipelines_test=aug_pipelines_test,
    initial_weights=initial_weights,pretrain=pretrain,
    num_epochs=num_epochs,numfit=numfit,
    freeze_num_epochs=freeze_num_epochs,freeze_numfit=freeze_numfit,
    print_report=True,
)

metrics = main()

inside create_model


Using cache found in /root/.cache/torch/hub/facebookresearch_barlowtwins_main


inside new my_splitter


epoch,train_loss,valid_loss,time
0,2.231801,,00:06
1,2.133175,,00:06
2,2.024523,,00:06
3,1.909593,,00:07
4,1.797583,,00:06
5,1.687701,,00:06
6,1.578967,,00:06
7,1.478047,,00:06
8,1.384773,,00:06
9,1.301351,,00:06


  warn("Your generator is empty.")


epoch,train_loss,valid_loss,time
0,0.287917,,00:06
1,0.275527,,00:06
2,0.244412,,00:07
3,0.236292,,00:06
4,0.236868,,00:06
5,0.229989,,00:06
6,0.222742,,00:06
7,0.212479,,00:06
8,0.203615,,00:06
9,0.205114,,00:06


                            precision    recall  f1-score   support

         actinic keratosis       0.50      0.60      0.55        20
      basal cell carcinoma       0.68      0.65      0.67        20
            dermatofibroma       0.74      0.74      0.74        19
                  melanoma       0.28      0.25      0.26        20
                     nevus       0.53      0.50      0.51        20
pigmented benign keratosis       0.50      0.55      0.52        20
      seborrheic keratosis       0.25      0.20      0.22        15
   squamous cell carcinoma       0.50      0.45      0.47        20
           vascular lesion       0.83      0.95      0.88        20

                  accuracy                           0.55       174
                 macro avg       0.53      0.54      0.54       174
              weighted avg       0.54      0.55      0.54       174



Wow! Still really bad performance. Adding a nonlinear head seems to really harm BT, so far. All I can think of is pretraining it really well.

## It appears the nonlinear head is harming BT performance. Weird. 

In [20]:
# NOTE(review): unconditional failure - presumably a deliberate stop so that "Run All" halts here; confirm before removing.
assert False

AssertionError: ignored

## Ok, let's just try adding linear layer to BT backbone

All we do is comment out one line in `create_model` (the `HeadEncoder` wrapping) and edit the splitter:

In [22]:
#| export


def create_model(which_model,device,ps=8192,n_in=3):
    """Build a Barlow Twins model on a plain resnet50 encoder (no `HeadEncoder` wrapper).

    which_model: 'bt_pretrain' (Barlow Twins weights from torch.hub), 'no_pretrain'
                 (`resnet50()` with no weights argument) or 'supervised_pretrain'
                 (`resnet50(weights='IMAGENET1K_V2')`).
    device:      string or `torch.device`; the models are moved only for CUDA devices.
    ps:          passed as both `hidden_size` and `projection_size` of the projector.
    n_in:        accepted for interface compatibility; not used here.

    Returns `(model, encoder)`. Raises `ValueError` for an unknown `which_model`.
    """
    print('inside create_model')

    if which_model == 'bt_pretrain':
        model = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')
    elif which_model == 'no_pretrain':
        model = resnet50()
    elif which_model == 'supervised_pretrain':
        model = resnet50(weights='IMAGENET1K_V2')
    else:
        #fail loudly instead of hitting an unbound `model` further down
        raise ValueError(f"Unknown which_model={which_model!r}: expected 'bt_pretrain', 'no_pretrain' or 'supervised_pretrain'")

    encoder = get_resnet_encoder(model)
    #encoder = HeadEncoder(encoder,device='cpu')  #nonlinear head deliberately disabled in this variant

    model = create_barlow_twins_model(encoder, hidden_size=ps,projection_size=ps,nlayers=3)

    #str() so that torch.device('cuda') and strings like 'cuda:0' are handled too
    if str(device).startswith('cuda'):
        model.to(device)
        encoder.to(device)

    return model,encoder


def my_splitter(m):
    "Split `sequential(encoder, linear)` into two parameter groups: the encoder's layers, and the linear layer."
    print('inside new my_splitter')
    body = sequential(*m[0])
    head = sequential(m[1])
    return L(body,head).map(params)

In [24]:
# # #Verify that splitter freezes expected part of model, from linear point of view:

#Plain encoder (no nonlinear head) with a single linear layer on top
bt_model,encoder = create_model(which_model='bt_pretrain',ps=8192,device=device)
model = sequential(encoder,nn.Linear(2048,9))

learn = Learner(
    dls_tune,model,
    splitter=my_splitter,
    cbs=[LinearBt(aug_pipelines=aug_pipelines_tune,n_in=3)],
    wd=0.0,
)
learn.freeze()
print('resnet should be frozen, then should just have unfrozen linear layer')
learn.summary()


inside create_model


Using cache found in /root/.cache/torch/hub/facebookresearch_barlowtwins_main


inside new my_splitter
resnet should be frozen, encoder_head + linear layer unfrozen


Sequential (Input shape: 256 x 3 x 128 x 128)
Layer (type)         Output Shape         Param #    Trainable 
                     256 x 64 x 64 x 64  
Conv2d                                    9408       False     
BatchNorm2d                               128        True      
ReLU                                                           
____________________________________________________________________________
                     256 x 64 x 32 x 32  
MaxPool2d                                                      
Conv2d                                    4096       False     
BatchNorm2d                               128        True      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
____________________________________________________________________________
                     256 x 256 x 32 x 32 
Conv2d                                    16384      False     
BatchNorm2d                       

## Experiment: linear layer only, compared with the nonlinear head

In [26]:
#Settings that differ from the defaults
initial_weights,pretrain = 'bt_pretrain',False
numfit,freeze_numfit = 50,6
#placeholders: main_train only reads these inside its `if self.pretrain` branch
num_epochs = freeze_num_epochs = 'na'

#`avg` accumulates the accuracy of three independent runs; the mean is shown as the cell output
avg = 0
for _ in range(3):
    main = main_train(
        dls_train=dls_train,dls_tune=dls_tune,dls_valid=dls_valid,
        xval=xval,yval=yval,
        aug_pipelines=aug_pipelines,aug_pipelines_tune=aug_pipelines_tune,aug_pipelines_test=aug_pipelines_test,
        initial_weights=initial_weights,pretrain=pretrain,
        num_epochs=num_epochs,numfit=numfit,
        freeze_num_epochs=freeze_num_epochs,freeze_numfit=freeze_numfit,
        print_report=True,
    )
    metrics = main()
    avg += metrics['acc']

avg/3



inside create_model


Using cache found in /root/.cache/torch/hub/facebookresearch_barlowtwins_main


inside new my_splitter


epoch,train_loss,valid_loss,time
0,2.198341,,00:06
1,2.196352,,00:06
2,2.192322,,00:06
3,2.186077,,00:06
4,2.176804,,00:06
5,2.163485,,00:06


  warn("Your generator is empty.")


epoch,train_loss,valid_loss,time
0,2.06943,,00:06
1,2.068809,,00:06
2,2.059758,,00:06
3,2.053328,,00:06
4,2.044665,,00:06
5,2.0346,,00:06
6,2.023101,,00:06
7,2.011719,,00:06
8,1.996386,,00:06
9,1.976805,,00:06


                            precision    recall  f1-score   support

         actinic keratosis       0.64      0.70      0.67        20
      basal cell carcinoma       0.68      0.65      0.67        20
            dermatofibroma       0.80      0.84      0.82        19
                  melanoma       0.50      0.40      0.44        20
                     nevus       0.59      0.65      0.62        20
pigmented benign keratosis       0.56      0.70      0.62        20
      seborrheic keratosis       0.50      0.40      0.44        15
   squamous cell carcinoma       0.71      0.50      0.59        20
           vascular lesion       0.79      0.95      0.86        20

                  accuracy                           0.65       174
                 macro avg       0.64      0.64      0.64       174
              weighted avg       0.65      0.65      0.64       174

inside create_model


Using cache found in /root/.cache/torch/hub/facebookresearch_barlowtwins_main


inside new my_splitter


epoch,train_loss,valid_loss,time
0,2.202211,,00:07
1,2.199894,,00:06
2,2.196281,,00:06
3,2.189226,,00:07
4,2.179641,,00:06
5,2.165967,,00:06


  warn("Your generator is empty.")


epoch,train_loss,valid_loss,time
0,2.065496,,00:06
1,2.065332,,00:06
2,2.058458,,00:06
3,2.051941,,00:06
4,2.04421,,00:06
5,2.03434,,00:06
6,2.02471,,00:07
7,2.011851,,00:06
8,1.997306,,00:06
9,1.978611,,00:06


                            precision    recall  f1-score   support

         actinic keratosis       0.70      0.70      0.70        20
      basal cell carcinoma       0.69      0.55      0.61        20
            dermatofibroma       0.80      0.84      0.82        19
                  melanoma       0.44      0.35      0.39        20
                     nevus       0.55      0.55      0.55        20
pigmented benign keratosis       0.48      0.50      0.49        20
      seborrheic keratosis       0.47      0.53      0.50        15
   squamous cell carcinoma       0.63      0.60      0.62        20
           vascular lesion       0.76      0.95      0.84        20

                  accuracy                           0.62       174
                 macro avg       0.61      0.62      0.61       174
              weighted avg       0.62      0.62      0.62       174

inside create_model


Using cache found in /root/.cache/torch/hub/facebookresearch_barlowtwins_main


inside new my_splitter


epoch,train_loss,valid_loss,time
0,2.197782,,00:06
1,2.196333,,00:06
2,2.192182,,00:06
3,2.186058,,00:06
4,2.176118,,00:07
5,2.162399,,00:06


  warn("Your generator is empty.")


epoch,train_loss,valid_loss,time
0,2.0691,,00:06
1,2.061539,,00:06
2,2.052295,,00:06
3,2.048429,,00:07
4,2.04067,,00:07
5,2.032983,,00:06
6,2.023321,,00:06
7,2.009928,,00:06
8,1.99508,,00:06
9,1.974495,,00:07


                            precision    recall  f1-score   support

         actinic keratosis       0.59      0.65      0.62        20
      basal cell carcinoma       0.83      0.75      0.79        20
            dermatofibroma       0.84      0.84      0.84        19
                  melanoma       0.57      0.40      0.47        20
                     nevus       0.58      0.70      0.64        20
pigmented benign keratosis       0.48      0.60      0.53        20
      seborrheic keratosis       0.60      0.60      0.60        15
   squamous cell carcinoma       0.65      0.55      0.59        20
           vascular lesion       0.90      0.90      0.90        20

                  accuracy                           0.67       174
                 macro avg       0.67      0.67      0.67       174
              weighted avg       0.67      0.67      0.67       174



0.6455938617388407

So just adding linear layer: ~ 0.65 accuracy, vs. ~ 0.6 for training a nonlinear head. 

Ok, maybe. But might need to compare to just linear head later.

## Warning: we added new stuff above... i.e. linear probe. 

## Exploratory baseline ensembling with a nonlinear head, without pretraining:

First, we need to make sure the nonlinear head model is still around...

In [12]:
#| export

#| export

class HeadEncoder(nn.Module):
    "Basic nonlinear head on top of a resnet encoder: Linear(2048,2048) -> BatchNorm1d -> ReLU."
    def __init__(self,resnet_encoder,device='cuda'):
        super().__init__()
        self.resnet_encoder = resnet_encoder

        #BatchNorm1d is left on its default eps/momentum/affine/track_running_stats settings
        head_layers = (nn.Linear(2048,2048), nn.BatchNorm1d(2048), nn.ReLU(inplace=True))
        self.head_encoder = sequential(*head_layers)

        #move the whole module (encoder + head) to the requested device
        self.device = torch.device(device)
        self.to(self.device)

    def forward(self,x):
        "Run `x` through the resnet encoder and then the head."
        return self.head_encoder(self.resnet_encoder(x))

def create_model(which_model,device,ps=8192,n_in=3):
    """Build a Barlow Twins model and return `(model, encoder)`.

    `which_model` selects the initial resnet50 weights:
      - 'bt_pretrain': Barlow Twins weights loaded through `torch.hub`
      - 'no_pretrain': `resnet50()` with no weights argument
      - 'supervised_pretrain': `resnet50(weights='IMAGENET1K_V2')`

    The resnet encoder is wrapped in `HeadEncoder` (resnet + nonlinear head) and a
    3-layer projector with hidden and projection size `ps` is attached on top.
    Both returned modules are moved to the GPU only when `device == 'cuda'`.
    `n_in` is accepted for interface compatibility but is not used here.

    Raises `ValueError` if `which_model` is not one of the three options above.
    """
    print('inside create_model')

    if which_model == 'bt_pretrain':
        model = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')
    elif which_model == 'no_pretrain':
        model = resnet50()
    elif which_model == 'supervised_pretrain':
        model = resnet50(weights='IMAGENET1K_V2')
    else:
        # Fail early with a clear message instead of hitting an unbound `model` below.
        raise ValueError(
            f"Unknown which_model {which_model!r}; expected 'bt_pretrain', 'no_pretrain' or 'supervised_pretrain'"
        )

    encoder = get_resnet_encoder(model)
    # Built on the CPU first; moved to the GPU below if requested.
    encoder = HeadEncoder(encoder,device='cpu')

    model = create_barlow_twins_model(encoder, hidden_size=ps,projection_size=ps,nlayers=3)

    if device == 'cuda':
        model.cuda()
        encoder.cuda()


    return model,encoder

# Build a BT-initialised model/encoder pair; `bt_model` is used below to sanity-check the splitter.
bt_model,encoder = create_model(which_model='bt_pretrain',ps=8192,device=device)

def my_splitter_bt(m):
    "Split a BT model into two parameter groups: [resnet body, encoder head + projector]."
    resnet_body = sequential(*m.encoder.resnet_encoder)
    head_and_projector = sequential(m.encoder.head_encoder,m.projector)
    return L(resnet_body,head_and_projector).map(params)

# The splitter must return exactly two parameter groups.
test_eq(len(my_splitter_bt(bt_model)),2)

class main_train:
    """Instantiate and (optionally) train the encoder. Then fine-tune the supervised model. 
    Outputs metrics on validation data.

    Calling the instance runs `train_encoder()` followed by `fine_tune()` and
    returns the metrics produced by `fine_tune()`.
    """

    def __init__(self,
                 dls_train, #used for training BT (if pretrain=True)
                 dls_tune , #used for tuning
                 dls_valid, #used to compute metrics / evaluate results. 
                 xval, #currently `predict_model` below assumes this is entire validation / test data
                 yval,
                 aug_pipelines, #the aug pipeline for self-supervised learning
                 aug_pipelines_tune, #the aug pipeline for supervised learning
                 aug_pipelines_test, #test (or valid) time augmentations 
                 initial_weights, #Which initial weights to use
                 pretrain, #Whether to fit BT
                 num_epochs, #number of BT fit epochs
                 numfit, #number of tune_fit epochs
                 freeze_num_epochs, #How many epochs to freeze body for when training BT
                 freeze_numfit, #How many epochs to freeze body for when fine tuning
                 ps=8192, #projection size
                 n_in=3, #color channels
                 indim=2048, #dimension output of encoder (2048 for resnet50)
                 outdim=9, #number of classes
                 print_report=False, #F1 metrics etc
                 print_plot=False, #ROC curve
                 ):
        store_attr()
        self.vocab = self.dls_valid.vocab
        # `is_available` has to be called: the bare function object is always truthy,
        # which would select 'cuda' even on a CPU-only machine.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    @staticmethod
    def fit(learn,fit_type,epochs,freeze_epochs,initial_weights):
        """Dispatch to the fitting routine named by `fit_type`.

        'encoder_fine_tune' calls `learn.encoder_fine_tune` (Barlow Twins training);
        'fine_tune' calls `learn.linear_fine_tune` (supervised tuning).
        `initial_weights` is currently unused; it is kept so that a patched version
        can make e.g. fine_tune:supervised_pretrain differ from fine_tune:bt_pretrain.

        Raises `ValueError` for any other `fit_type`.
        """

        if fit_type == 'encoder_fine_tune': #i.e. barlow twins
            learn.encoder_fine_tune(epochs,freeze_epochs=freeze_epochs)

        elif fit_type == 'fine_tune':
            learn.linear_fine_tune(epochs,freeze_epochs=freeze_epochs)

        else:
            raise ValueError(f'Fit policy not of expected form: {fit_type!r}')

    def train_encoder(self):
        "Create the encoder (or reuse `self.encoder`) and, if `pretrain` is True, train it with the BT algorithm."

        # Look the attribute up explicitly rather than wrapping model construction in
        # try/except AttributeError, which would also hide unrelated AttributeErrors.
        encoder = getattr(self,'encoder',None)

        if encoder is None: #nothing to reuse: create from the requested initial weights
            bt_model,encoder = create_model(which_model=self.initial_weights,ps=self.ps,device=self.device)

        else: #get existing encoder and plonk on new projector
            encoder.cpu()
            bt_model = create_barlow_twins_model(encoder, hidden_size=self.ps,projection_size=self.ps,nlayers=3)
            if self.device == 'cuda':
                bt_model.cuda()

        if self.pretrain:
            bt_callback = BarlowTwins(self.aug_pipelines,n_in=self.n_in,lmb=1/self.ps,print_augs=False)
            learn = Learner(self.dls_train,bt_model,splitter=my_splitter_bt,cbs=[bt_callback])
            main_train.fit(learn,fit_type='encoder_fine_tune',
                           epochs=self.num_epochs,freeze_epochs=self.freeze_num_epochs,
                           initial_weights=self.initial_weights
                          )

        self.encoder = bt_model.encoder

    def fine_tune(self):
        "Fine tune in supervised fashion on `dls_tune`, then compute metrics on `xval`/`yval`."

        if getattr(self,'encoder',None) is None:
            # Use the instance's device rather than a notebook-level global.
            _,self.encoder = create_model(which_model=self.initial_weights,ps=self.ps,device=self.device)

        # Encoder + linear classifier; sizes come from the constructor (defaults 2048 -> 9).
        model = sequential(self.encoder,nn.Linear(self.indim,self.outdim))

        learn = Learner(self.dls_tune,model,splitter=my_splitter,cbs = [LinearBt(aug_pipelines=self.aug_pipelines_tune,n_in=self.n_in)],wd=0.0)

        main_train.fit(learn,fit_type='fine_tune',
                       epochs=self.numfit,freeze_epochs=self.freeze_numfit,
                       initial_weights=self.initial_weights
                      ) #fine tuning (don't confuse this with fit policy!)

        scores,preds, acc = predict_model(self.xval,self.yval,model=model,aug_pipelines_test=self.aug_pipelines_test,numavg=3)
        #metrics dict will have f1 score, auc etc etc
        metrics = classification_report_wrapper(preds, self.yval, self.vocab, print_report=self.print_report)
        auc_dict = plot_roc(self.yval,preds,self.vocab,print_plot=self.print_plot)

        # Attach raw outputs so that callers can ensemble several runs afterwards.
        metrics['acc'] = acc
        metrics['auc_dict'] = auc_dict
        metrics['scores'] = scores
        metrics['preds'] = preds
        metrics['xval'] = self.xval
        metrics['yval'] = self.yval

        return metrics

    def __call__(self):
        "Run the full pipeline: train (or extract) the encoder, then fine tune and return the metrics."

        self.train_encoder() #train (or extract) the encoder
        metrics = self.fine_tune()

        return metrics



inside create_model


Using cache found in /root/.cache/torch/hub/facebookresearch_barlowtwins_main


In [13]:
# Manual check: verify that `my_splitter_bt` freezes the expected part of the BT model.
# After `learn.freeze()` the summary should list the resnet as frozen and the
# encoder head + projector as trainable.

learn = Learner(dls_train,bt_model,splitter=my_splitter_bt,cbs=[BarlowTwins(aug_pipelines,n_in=3,lmb=1/8192,print_augs=False)])
learn.freeze()
print('resnet should be frozen, encoder head + projector unfrozen')
learn.summary()


resnet should be frozen, encoder head + projector unfrozen


BarlowTwinsModel (Input shape: 256 x 3 x 128 x 128)
Layer (type)         Output Shape         Param #    Trainable 
                     256 x 64 x 64 x 64  
Conv2d                                    9408       False     
BatchNorm2d                               128        True      
ReLU                                                           
____________________________________________________________________________
                     256 x 64 x 32 x 32  
MaxPool2d                                                      
Conv2d                                    4096       False     
BatchNorm2d                               128        True      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
____________________________________________________________________________
                     256 x 256 x 32 x 32 
Conv2d                                    16384      False     
BatchNorm2d                 

In [14]:
# Manual check: verify that `my_splitter` freezes the expected part of the supervised
# model (encoder + linear layer). A fresh BT-initialised encoder is created so this
# cell does not depend on any training done above.

bt_model,encoder = create_model(which_model='bt_pretrain',ps=8192,device=device)
model = sequential(encoder,nn.Linear(2048,9))
learn = Learner(dls_tune,model,splitter=my_splitter,cbs = [LinearBt(aug_pipelines=aug_pipelines_tune,n_in=3)],wd=0.0)
learn.freeze()
print('resnet should be frozen, then should just have unfrozen linear layer')
learn.summary()


inside create_model


Using cache found in /root/.cache/torch/hub/facebookresearch_barlowtwins_main


inside new my_splitter
resnet should be frozen, then should just have unfrozen linear layer


Sequential (Input shape: 256 x 3 x 128 x 128)
Layer (type)         Output Shape         Param #    Trainable 
                     256 x 64 x 64 x 64  
Conv2d                                    9408       False     
BatchNorm2d                               128        True      
ReLU                                                           
____________________________________________________________________________
                     256 x 64 x 32 x 32  
MaxPool2d                                                      
Conv2d                                    4096       False     
BatchNorm2d                               128        True      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
____________________________________________________________________________
                     256 x 256 x 32 x 32 
Conv2d                                    16384      False     
BatchNorm2d                       

## Experiment: pretrain the head for a really long time, with resnet frozen.

In [15]:
# Non-default inputs for a single `main_train` run.
# NOTE(review): the saved output below this cell shows 10 epochs per training phase,
# which does not match num_epochs=200 / numfit=50 — it may come from an earlier run; confirm.
initial_weights = 'bt_pretrain'
pretrain=True
numfit=50 # number of supervised fine-tuning epochs
num_epochs=200 # number of BT fit epochs
freeze_num_epochs = 1 # epochs with the body frozen during BT training
freeze_numfit=6 # epochs with the body frozen during fine tuning


main = main_train(dls_train=dls_train,dls_tune=dls_tune,dls_valid=dls_valid, xval=xval, yval=yval,
        aug_pipelines=aug_pipelines, aug_pipelines_tune=aug_pipelines_tune, aug_pipelines_test=aug_pipelines_test, 
        initial_weights=initial_weights,pretrain=pretrain,
        num_epochs=num_epochs,numfit=numfit,freeze_num_epochs=freeze_num_epochs,freeze_numfit=freeze_numfit,
        print_report=True,
                )

# Runs BT pretraining followed by supervised fine tuning, then reports validation accuracy.
metrics = main()
print(metrics['acc'])


inside create_model


Using cache found in /root/.cache/torch/hub/facebookresearch_barlowtwins_main


epoch,train_loss,valid_loss,time
0,5618.505859,,00:07


  warn("Your generator is empty.")


epoch,train_loss,valid_loss,time
0,2225.252686,,00:07
1,2220.608398,,00:06
2,2020.624756,,00:06
3,1898.112427,,00:06
4,1812.766479,,00:06
5,1726.94397,,00:06
6,1662.019775,,00:06
7,1600.636841,,00:07
8,1549.371338,,00:07
9,1498.420288,,00:06


inside new my_splitter


epoch,train_loss,valid_loss,time
0,2.209545,,00:07
1,2.074716,,00:07
2,1.903065,,00:06
3,1.718835,,00:06
4,1.545133,,00:06
5,1.394172,,00:06


epoch,train_loss,valid_loss,time
0,0.575437,,00:07
1,0.549418,,00:06
2,0.523801,,00:06
3,0.496486,,00:06
4,0.469012,,00:07
5,0.446456,,00:06
6,0.421263,,00:07
7,0.398532,,00:06
8,0.371554,,00:06
9,0.352936,,00:07


                            precision    recall  f1-score   support

         actinic keratosis       0.57      0.65      0.60        20
      basal cell carcinoma       0.50      0.50      0.50        20
            dermatofibroma       0.74      0.74      0.74        19
                  melanoma       0.32      0.30      0.31        20
                     nevus       0.48      0.50      0.49        20
pigmented benign keratosis       0.46      0.55      0.50        20
      seborrheic keratosis       0.30      0.20      0.24        15
   squamous cell carcinoma       0.56      0.45      0.50        20
           vascular lesion       0.82      0.90      0.86        20

                  accuracy                           0.54       174
                 macro avg       0.53      0.53      0.53       174
              weighted avg       0.53      0.54      0.53       174

0.540229856967926


## Wow, it didn't work at all.

TODO:

Depending on above results, we need to:

- Explore the effect of using `fine_tune` instead of `linear_fine_tune`. i.e. using different learning rates.
- Perhaps try using a nice FastAI head, via `create_head` magic. Perhaps we can pretrain this guy with BT instead?
- Run head-probing experiments, where the head may be linear or nonlinear. That is, freeze the resnet the whole way. Not sure whether to bother with this. Probably not, actually. The paper suggests LP-FT works best, after all.


Just look at bt weights first:

In [41]:
def run_main_train(initial_weights,num_epochs,freeze_numfit,freeze_num_epochs,numfit=100,pretrain=False,num=5):
    "Run `main_train` `num` times with identical settings; return `{run_index: metrics}`."

    # Settings shared by every run. Dataloaders, validation data and aug pipelines
    # are read from the notebook's global namespace.
    shared_kwargs = dict(
        dls_train=dls_train, dls_tune=dls_tune, dls_valid=dls_valid, xval=xval, yval=yval,
        aug_pipelines=aug_pipelines, aug_pipelines_tune=aug_pipelines_tune, aug_pipelines_test=aug_pipelines_test,
        initial_weights=initial_weights, pretrain=pretrain,
        num_epochs=num_epochs, numfit=numfit, freeze_num_epochs=freeze_num_epochs, freeze_numfit=freeze_numfit,
        print_report=True,
    )

    results = {}
    for run_idx in range(num):
        trainer = main_train(**shared_kwargs) # fresh instance, so no encoder is carried between runs
        results[run_idx] = trainer()

    return results
        

## Experiment: Baseline, non linear head with no pretraining. Ensemble.

In [None]:
# Baseline: start from BT weights, no additional BT training (pretrain=False), 3 runs.
# num_epochs / freeze_num_epochs are placeholders ('na'); `main_train` only uses them when pretrain=True.
initial_weights='bt_pretrain'
main_dict = run_main_train(initial_weights=initial_weights,numfit=50,num_epochs='na',freeze_num_epochs='na',pretrain=False,freeze_numfit=3,num=3)

from itertools import combinations
from statistics import mean
from statistics import stdev

print('Results for ensembling within bt weights:')

bt_results = list(main_dict.values())

# Per-run accuracies with their mean and standard deviation.
_bt_results = [k['acc'] for k in bt_results]
print(_bt_results)
print(mean(_bt_results))
print(stdev(_bt_results))

# Ensemble every pair of runs from their stored scores and compare with the individual accuracies.
bt_results = list(combinations(bt_results,2)) #all pairs of results. So for num=3, will be 3
for v in bt_results:

    print(f"\nAcc of first guy in ensemble is: {v[0]['acc']}")
    print(f"Acc of second guy in ensemble is: {v[1]['acc']}")
    _,acc = predict_ensemble(yval=yval,scores1=v[0]['scores'],scores2=v[1]['scores'])
    print(f'Acc of ensemble is:{acc}\n')

In [None]:
# Same baseline as the numfit=50 experiment, but with numfit=30 fine-tuning epochs.
# num_epochs / freeze_num_epochs are placeholders ('na'); `main_train` only uses them when pretrain=True.
initial_weights='bt_pretrain'
main_dict = run_main_train(initial_weights=initial_weights,numfit=30,num_epochs='na',freeze_num_epochs='na',pretrain=False,freeze_numfit=3,num=3)

from itertools import combinations
from statistics import mean
from statistics import stdev

print('Results for ensembling within bt weights:')

bt_results = list(main_dict.values())

# Per-run accuracies with their mean and standard deviation.
_bt_results = [k['acc'] for k in bt_results]
print(_bt_results)
print(mean(_bt_results))
print(stdev(_bt_results))

# Ensemble every pair of runs from their stored scores and compare with the individual accuracies.
bt_results = list(combinations(bt_results,2)) #all pairs of results. So for num=3, will be 3
for v in bt_results:

    print(f"\nAcc of first guy in ensemble is: {v[0]['acc']}")
    print(f"Acc of second guy in ensemble is: {v[1]['acc']}")
    _,acc = predict_ensemble(yval=yval,scores1=v[0]['scores'],scores2=v[1]['scores'])
    print(f'Acc of ensemble is:{acc}\n')

In [None]:
# Always raises AssertionError. Presumably a deliberate stop so "Run All" halts here — confirm.
assert False

Now, let's look at doing some pretraining. For now, say BT pretrain for 10 epochs. Freezing the projector doesn't make sense remember: we are aligning the random head and projector. But, check above: the backbone encoder is frozen. 

Thinking about this more: training the encoder_head on a frozen backbone will make them less variable.

First, let's try an ensemble with the resnet frozen the whole way through.

Notice that the resnet is kept frozen the whole way through (you can look up above to verify the freeze is working as expected):

## Experiment: pretrain the nonlinear head with BT

In [None]:
#| export
@patch
@delegates(Learner.fit_one_cycle)
def encoder_fine_tune(self:Learner, epochs, base_lr=2e-3, freeze_epochs=1, lr_mult=100,
              pct_start=0.3, div=5.0, **kwargs):
    """Fine tuner to use with bt initial weights; the resnet stays frozen during both fits.

    Phase 1 fits for `freeze_epochs` at `base_lr`; phase 2 fits for `epochs` at half
    that rate, still frozen, so only the encoder head + projector are trained.
    `lr_mult` is accepted but not used in this variant.
    """

    # Phase 1: frozen warm-up.
    self.freeze()
    self.fit_one_cycle(freeze_epochs, slice(base_lr), pct_start=0.99, **kwargs)

    # Phase 2: deliberately no unfreeze here; same (halved) rate at both ends of the slice.
    halved_lr = base_lr / 2
    self.fit_one_cycle(epochs, slice(halved_lr, halved_lr), pct_start=pct_start, div=div, **kwargs)

    # Leave the learner unfrozen once training is done.
    self.unfreeze()


if __name__ == "__main__":
    # Experiment: BT-pretrain just the head (resnet frozen), then fine tune; 3 runs.
    from itertools import combinations # local import so this exported cell is self-contained

    initial_weights='bt_pretrain'
    num_epochs=100
    freeze_num_epochs=1
    freeze_numfit=3
    pretrain=True
    # Everything is passed by keyword: a positional argument may not follow a keyword
    # argument, so mixing them after `num_epochs=50` is a SyntaxError.
    # NOTE(review): the call uses 50 BT epochs, not the `num_epochs=100` set above — confirm which is intended.
    main_dict = run_main_train(initial_weights=initial_weights,num_epochs=50,
                               freeze_numfit=freeze_numfit,freeze_num_epochs=freeze_num_epochs,
                               pretrain=pretrain,num=3)



    print('Results for ensembling with bt weights, where we just trained the head')
    bt_results = list(main_dict.values())
    print([bt_results[i]['acc'] for i in range(len(bt_results))])
    bt_results = list(combinations(bt_results,2)) #all pairs of results. So for num=3, will be 3
    for v in bt_results:

        print(f"\nAcc of first guy in ensemble is: {v[0]['acc']}")
        print(f"Acc of second guy in ensemble is: {v[1]['acc']}")
        _,acc = predict_ensemble(yval=yval,scores1=v[0]['scores'],scores2=v[1]['scores'])
        print(f'Acc of ensemble is:{acc}\n')
        


A natural thing to try is to just plonk a projector on the end, train BT as usual (unfrozen encoder). Then plonk a random head on the end and fine tune as usual. The idea here is we will create some variation in the BT weights without destroying them: should increase ensemble performance slightly. Note that training just random heads with BT (as above) will DECREASE the variation in the heads: they have gone from random, to all trained on the same objective. 


To do this, we will need to edit everything again (gah!)

Main points to edit in the cell(s) below:

- create_model is as original: resnet + projector
- We need to check bt_splitter and verify is working in cell below
- We need to check the linear splitter as well...
- We need to include encoder_fine_tune, to work as before.
- fine_tune now needs a random encoder_head + linear layer

In [None]:
# Recap marker for the two experiments above; it only prints a message.
print('We ran two experiments above: basline with head (no pretraining) and with BT pretraining just the head')

## We don't want to run any of this stuff:

In [None]:
# Deliberate stop: always raises AssertionError so that "Run All" does not continue past this point.
assert False

In [None]:
#| export

def create_model(which_model,device,ps=8192,n_in=3):
    """Build a Barlow Twins model (resnet encoder + projector) and return `(model, encoder)`.

    `which_model` selects the initial resnet50 weights:
      - 'bt_pretrain': Barlow Twins weights loaded through `torch.hub`
      - 'no_pretrain': `resnet50()` with no weights argument
      - 'supervised_pretrain': `resnet50(weights='IMAGENET1K_V2')`

    In this variant the encoder is the bare resnet encoder (no `HeadEncoder` wrapper).
    A 3-layer projector with hidden and projection size `ps` is attached on top.
    Both returned modules are moved to the GPU only when `device == 'cuda'`.
    `n_in` is accepted for interface compatibility but is not used here.

    Raises `ValueError` if `which_model` is not one of the three options above.
    """
    print('inside create_model')

    if which_model == 'bt_pretrain':
        model = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')
    elif which_model == 'no_pretrain':
        model = resnet50()
    elif which_model == 'supervised_pretrain':
        model = resnet50(weights='IMAGENET1K_V2')
    else:
        # Fail early with a clear message instead of hitting an unbound `model` below.
        raise ValueError(
            f"Unknown which_model {which_model!r}; expected 'bt_pretrain', 'no_pretrain' or 'supervised_pretrain'"
        )

    encoder = get_resnet_encoder(model)

    model = create_barlow_twins_model(encoder, hidden_size=ps,projection_size=ps,nlayers=3)

    if device == 'cuda':
        model.cuda()
        encoder.cuda()


    return model,encoder

class main_train:
    """Instantiate and (optionally) train the encoder. Then fine-tune the supervised model. 
    Outputs metrics on validation data.

    In this variant BT training acts on resnet + projector; a fresh nonlinear head
    (`HeadEncoder`) and a linear layer are added only at fine-tuning time.
    Calling the instance runs `train_encoder()` followed by `fine_tune()` and
    returns the metrics produced by `fine_tune()`.
    """

    def __init__(self,
                 dls_train, #used for training BT (if pretrain=True)
                 dls_tune , #used for tuning
                 dls_valid, #used to compute metrics / evaluate results. 
                 xval, #currently `predict_model` below assumes this is entire validation / test data
                 yval,
                 aug_pipelines, #the aug pipeline for self-supervised learning
                 aug_pipelines_tune, #the aug pipeline for supervised learning
                 aug_pipelines_test, #test (or valid) time augmentations 
                 initial_weights, #Which initial weights to use
                 pretrain, #Whether to fit BT
                 num_epochs, #number of BT fit epochs
                 numfit, #number of tune_fit epochs
                 freeze_num_epochs, #How many epochs to freeze body for when training BT
                 freeze_numfit, #How many epochs to freeze body for when fine tuning
                 ps=8192, #projection size
                 n_in=3, #color channels
                 indim=2048, #dimension output of encoder (2048 for resnet50)
                 outdim=9, #number of classes
                 print_report=False, #F1 metrics etc
                 print_plot=False, #ROC curve
                 ):
        store_attr()
        self.vocab = self.dls_valid.vocab
        # `is_available` has to be called: the bare function object is always truthy,
        # which would select 'cuda' even on a CPU-only machine.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    @staticmethod
    def fit(learn,fit_type,epochs,freeze_epochs,initial_weights):
        """Dispatch to the fitting routine named by `fit_type`.

        'encoder_fine_tune' calls `learn.encoder_fine_tune` (Barlow Twins training);
        'fine_tune' calls `learn.linear_fine_tune` (supervised tuning).
        `initial_weights` is currently unused; it is kept so that a patched version
        can make e.g. fine_tune:supervised_pretrain differ from fine_tune:bt_pretrain.

        Raises `ValueError` for any other `fit_type`.
        """

        if fit_type == 'encoder_fine_tune': #i.e. barlow twins
            learn.encoder_fine_tune(epochs,freeze_epochs=freeze_epochs)

        elif fit_type == 'fine_tune':
            learn.linear_fine_tune(epochs,freeze_epochs=freeze_epochs)

        else:
            raise ValueError(f'Fit policy not of expected form: {fit_type!r}')

    def train_encoder(self):
        "Create the encoder (or reuse `self.encoder`) and, if `pretrain` is True, train it with the BT algorithm."

        # Look the attribute up explicitly rather than wrapping model construction in
        # try/except AttributeError, which would also hide unrelated AttributeErrors.
        encoder = getattr(self,'encoder',None)

        if encoder is None: #nothing to reuse: create from the requested initial weights
            bt_model,encoder = create_model(which_model=self.initial_weights,ps=self.ps,device=self.device)

        else: #get existing encoder and plonk on new projector
            encoder.cpu()
            bt_model = create_barlow_twins_model(encoder, hidden_size=self.ps,projection_size=self.ps,nlayers=3)
            if self.device == 'cuda':
                bt_model.cuda()

        if self.pretrain:
            bt_callback = BarlowTwins(self.aug_pipelines,n_in=self.n_in,lmb=1/self.ps,print_augs=False)
            learn = Learner(self.dls_train,bt_model,splitter=my_splitter_bt,cbs=[bt_callback])
            main_train.fit(learn,fit_type='encoder_fine_tune',
                           epochs=self.num_epochs,freeze_epochs=self.freeze_num_epochs,
                           initial_weights=self.initial_weights
                          )

        self.encoder = bt_model.encoder

    def fine_tune(self):
        "Fine tune in supervised fashion on `dls_tune`, then compute metrics on `xval`/`yval`."

        if getattr(self,'encoder',None) is None:
            # Use the instance's device rather than a notebook-level global.
            _,self.encoder = create_model(which_model=self.initial_weights,ps=self.ps,device=self.device)

        # resnet + fresh nonlinear head, placed on the instance's device rather than a hard-coded 'cuda'.
        encoder = HeadEncoder(self.encoder,device=self.device)
        # + linear layer; sizes come from the constructor (defaults 2048 -> 9).
        model = sequential(encoder,nn.Linear(self.indim,self.outdim))

        learn = Learner(self.dls_tune,model,splitter=my_splitter,cbs = [LinearBt(aug_pipelines=self.aug_pipelines_tune,n_in=self.n_in)],wd=0.0)

        main_train.fit(learn,fit_type='fine_tune',
                       epochs=self.numfit,freeze_epochs=self.freeze_numfit,
                       initial_weights=self.initial_weights
                      ) #fine tuning (don't confuse this with fit policy!)

        scores,preds, acc = predict_model(self.xval,self.yval,model=model,aug_pipelines_test=self.aug_pipelines_test,numavg=3)
        #metrics dict will have f1 score, auc etc etc
        metrics = classification_report_wrapper(preds, self.yval, self.vocab, print_report=self.print_report)
        auc_dict = plot_roc(self.yval,preds,self.vocab,print_plot=self.print_plot)

        # Attach raw outputs so that callers can ensemble several runs afterwards.
        metrics['acc'] = acc
        metrics['auc_dict'] = auc_dict
        metrics['scores'] = scores
        metrics['preds'] = preds
        metrics['xval'] = self.xval
        metrics['yval'] = self.yval

        return metrics

    def __call__(self):
        "Run the full pipeline: train (or extract) the encoder, then fine tune and return the metrics."

        self.train_encoder() #train (or extract) the encoder
        metrics = self.fine_tune()

        return metrics

def my_splitter(m):
    "Split resnet -> nonlinear head -> linear layer into two parameter groups: [resnet, head + linear layer]."
    head_encoder_model = m[0] # HeadEncoder-style module: .resnet_encoder and .head_encoder
    resnet_body = sequential(*head_encoder_model.resnet_encoder)
    head_and_linear = sequential(head_encoder_model.head_encoder,m[1])
    return L(resnet_body,head_and_linear).map(params)


def my_splitter_bt(m):
    "Split a BT model (resnet encoder + projector) into two parameter groups: [resnet, projector]."
    resnet_body = sequential(*m.encoder)
    projector_only = sequential(m.projector)
    return L(resnet_body,projector_only).map(params)



In [None]:
#| export

@patch
@delegates(Learner.fit_one_cycle)
def encoder_fine_tune(self:Learner, epochs, base_lr=2e-3, freeze_epochs=1, lr_mult=100,
              pct_start=0.3, div=5.0, **kwargs):
    """Fine tuner to use with bt initial weights: frozen warm-up, then train everything.

    Phase 1 fits for `freeze_epochs` at `base_lr` with the resnet frozen; phase 2
    unfreezes and fits for `epochs` at half that rate, the same for every group.
    `lr_mult` is accepted but not used in this variant.
    """

    # Phase 1: frozen warm-up.
    self.freeze()
    print('froze resnet')
    self.fit_one_cycle(freeze_epochs, slice(base_lr), pct_start=0.99, **kwargs)

    # Phase 2: the resnet is trained too.
    self.unfreeze()
    print('unfroze resnet')
    halved_lr = base_lr / 2
    self.fit_one_cycle(epochs, slice(halved_lr, halved_lr), pct_start=pct_start, div=div, **kwargs)

    # Kept from the original flow; the learner is already unfrozen at this point.
    self.unfreeze()

In [None]:
#test bt split
bt_model,encoder = create_model(which_model='bt_pretrain',ps=8192,device=device)

#| hide

#test : manual. BT

learn = Learner(dls_train,bt_model,splitter=my_splitter_bt,cbs=[BarlowTwins(aug_pipelines,n_in=3,lmb=1/8192,print_augs=False)])
learn.freeze()
print('resnet (frozen) + projector')
learn.summary()

In [None]:
#test linear split
bt_model,encoder = create_model(which_model='bt_pretrain',ps=8192,device=device)
encoder = HeadEncoder(encoder,device='cuda') #resnet + nonlinear head
model = sequential(encoder,nn.Linear(2048,9)) #+ linear layer. 
model.cuda()

learn = Learner(dls_tune,model,splitter=my_splitter,cbs = [LinearBt(aug_pipelines=aug_pipelines_tune,n_in=3)],wd=0.0)
learn.freeze()
print('resnet (frozen) + unfrozen head and linear layer')
learn.summary()

Go back over your checklist!

## Alright, now our thesis is that pretraining with BT, so long as we don't do it for too long and destroy the representations, will cause diversity and improve ensembling:

(As an aside, we could add a callback/model that implements our ensembling idea IN PROJECTOR SPACE. The idea is, basically, to push the projectors apart (while the resnet is frozen) as well as aligning them with the resnet, then just train as usual. Makes sense!)

In [None]:
# Experiment settings: BT pretraining from BT weights, 3 runs.
initial_weights='bt_pretrain'
num_epochs=10 # BT epochs
freeze_num_epochs=10 # BT epochs with the resnet frozen
freeze_numfit=3 # fine-tuning epochs with the body frozen
pretrain=True

# numfit is not passed, so run_main_train's default (100) applies.
main_dict = run_main_train(initial_weights,num_epochs,freeze_numfit,freeze_num_epochs,pretrain=pretrain,num=3)

In [None]:
from itertools import combinations

# Report the accuracy of each run in `main_dict`, then of every pairwise ensemble built from the stored scores.
print('Results for ensembling with bt weights, where we trained the usual way (freeze resnet, then unfreeeze)')
bt_results = list(main_dict.values())
print([bt_results[i]['acc'] for i in range(len(bt_results))])
bt_results = list(combinations(bt_results,2)) #all pairs of results. So for num=3, will be 3
for v in bt_results:

    print(f"\nAcc of first guy in ensemble is: {v[0]['acc']}")
    print(f"Acc of second guy in ensemble is: {v[1]['acc']}")
    _,acc = predict_ensemble(yval=yval,scores1=v[0]['scores'],scores2=v[1]['scores'])
    print(f'Acc of ensemble is:{acc}\n')
    

## Ok, didn't do much. Seems performance is (at least potentially) similar to before. So, a natural next thing to try is to just do the same experiment, but for longer. Also, it makes sense to use a lower base learning rate. Also, it would have been better to make things more extensible...

## In this experiment we edited encoder_fine_tune (essentially lowered the learning rate), and trained for larger number of epochs.

## Lesson: we should perhaps have made main_train even more extensible: all hyperparameters it depends on should be passable (even as e.g. dictionaries). Anyway:

In [None]:

@patch
@delegates(Learner.fit_one_cycle)
def encoder_fine_tune(self:Learner, epochs, base_lr=1e-3, freeze_epochs=1, lr_mult=100,
              pct_start=0.3, div=5.0, **kwargs):
    """Fine tuner to use with bt initial weights: frozen warm-up, then discriminative rates.

    Phase 1 fits for `freeze_epochs` at `base_lr` with the resnet frozen; phase 2
    unfreezes and fits for `epochs` with rates from `base_lr/2/lr_mult` up to `base_lr/2`.
    """

    # Phase 1: frozen warm-up.
    self.freeze()
    print('froze resnet')
    self.fit_one_cycle(freeze_epochs, slice(base_lr), pct_start=0.99, **kwargs)

    # Phase 2: the resnet is trained too, at a rate `lr_mult` times lower at the bottom of the slice.
    self.unfreeze()
    print('unfroze resnet')
    top_lr = base_lr / 2
    lr_range = slice(top_lr/lr_mult, top_lr)
    self.fit_one_cycle(epochs, lr_range, pct_start=pct_start, div=div, **kwargs)

    # Kept from the original flow; the learner is already unfrozen at this point.
    self.unfreeze()


# Experiment settings: as the 10-epoch BT experiment, but with 50 BT epochs
# and the `encoder_fine_tune` defined directly above.
initial_weights='bt_pretrain'
num_epochs=50 # BT epochs
freeze_num_epochs=10 # BT epochs with the resnet frozen
freeze_numfit=3 # fine-tuning epochs with the body frozen
pretrain=True

# numfit is not passed, so run_main_train's default (100) applies.
main_dict = run_main_train(initial_weights,num_epochs,freeze_numfit,freeze_num_epochs,pretrain=pretrain,num=3)


from itertools import combinations

# Report the accuracy of each run, then of every pairwise ensemble built from the stored scores.
print('Results for ensembling with bt weights, where we trained the usual way (freeze resnet, then unfreeeze)')
bt_results = list(main_dict.values())
print([bt_results[i]['acc'] for i in range(len(bt_results))])
bt_results = list(combinations(bt_results,2)) #all pairs of results. So for num=3, will be 3
for v in bt_results:

    print(f"\nAcc of first guy in ensemble is: {v[0]['acc']}")
    print(f"Acc of second guy in ensemble is: {v[1]['acc']}")
    _,acc = predict_ensemble(yval=yval,scores1=v[0]['scores'],scores2=v[1]['scores'])
    print(f'Acc of ensemble is:{acc}\n')
    

## An alternative approach could be to train just the head with BT, but add a decorrelation penalty term. This probably only makes sense if pretraining the head really does help (which we think it should).