In [None]:
!git clone https://github.com/evgerritz/calligraphy_data.git

Cloning into 'calligraphy_data'...
remote: Enumerating objects: 105091, done.[K
remote: Counting objects: 100% (43/43), done.[K
remote: Compressing objects: 100% (7/7), done.[K
remote: Total 105091 (delta 1), reused 43 (delta 1), pack-reused 105048[K
Receiving objects: 100% (105091/105091), 167.90 MiB | 37.01 MiB/s, done.
Resolving deltas: 100% (502/502), done.
Updating files: 100% (105081/105081), done.


In [None]:

!pip install datasets
!pip install accelerate -U

Collecting datasets
  Downloading datasets-2.16.1-py3-none-any.whl (507 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m507.1/507.1 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: dill, multiprocess, datasets
Successfully installed datasets-2.16.1 dill-0.3.7 multiprocess-0.70.15
Collecting accelerate
  Downloading accelerate-0.26.1-py3-none-any.whl (270 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m270.9/270.9 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
Installing collected pack

In [None]:
import pickle
from google.colab import files
import os
import torch
import numpy as np
from glob import glob
from PIL import Image
from datasets import Dataset, load_metric, load_dataset
import transformers
from transformers import AutoImageProcessor, ViTImageProcessor, Trainer, TrainingArguments, TrainerCallback, \
    ResNetForImageClassification, ViTForImageClassification, SwinForImageClassification, PvtForImageClassification, CvtForImageClassification, PoolFormerForImageClassification, ConvNextV2ForImageClassification
from sklearn.metrics import normalized_mutual_info_score
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA, KernelPCA
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
from matplotlib import rcParams
from torchvision import transforms
import shutil



# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Font type 42 (TrueType) keeps text in saved PDF/PS figures as editable text.
rcParams['pdf.fonttype'] = 42
rcParams['ps.fonttype'] = 42

# Output folders: 'data' receives per-model metrics/plots, 'Trainers' the Trainer checkpoints.
# exist_ok=True makes the cell safe to re-run.
for out_dir in ('data', 'Trainers'):
    os.makedirs(out_dir, exist_ok=True)

In [None]:
# Test browser download
with open('test.pkl','wb') as f:
  pickle.dump([1,2,3],f)
files.download('test.pkl')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Calligraphy Dataset

In [None]:
### SEEN CLASSES
np.random.seed(42)
downscale = 1
#1) grab all img paths
num_classes = 15
max_n_train, max_n_test = 9100, 2500


def _class_of(img_path):
    """Return the class name of an image path, i.e. the name of its parent directory."""
    return os.path.basename(os.path.dirname(img_path))


# glob() returns paths in arbitrary (filesystem-dependent) order, so sort before sampling;
# otherwise the fixed seed above does not give the same sample on every machine/run.
# NOTE: sorting changes which images are drawn compared with runs of the unsorted version.
train_paths = sorted(glob(f'calligraphy_data/train_{num_classes}/*/*.jpg'))
train_IDs = np.random.choice(train_paths, max_n_train, False)
print('Example train_id:',train_IDs[0])
test_paths = sorted(glob(f'calligraphy_data/test_{num_classes}/*/*.jpg'))
test_IDs = np.random.choice(test_paths, max_n_test, False)
# partition = {'train':train_IDs, 'validation':test_IDs}
print(len(train_IDs),len(test_IDs))

# Derive the class list from every available image rather than from the random test sample,
# so a class that happens to be missing from the sample cannot cause a KeyError below.
class_names = sorted({_class_of(p) for p in train_paths + test_paths})
assert len(class_names) == num_classes, f'expected {num_classes} classes, found {len(class_names)}'
class_name_to_ix = {cname:ci for ci,cname in enumerate(class_names)}
ix_to_class_name = {ci:cname for ci,cname in enumerate(class_names)}
print(class_name_to_ix)

# assign numeric label to each training and test sample
train_labels = [class_name_to_ix[_class_of(tid)] for tid in train_IDs]
test_labels = [class_name_to_ix[_class_of(tid)] for tid in test_IDs]

Example train_id: calligraphy_data/train_15/lgq/0424.jpg
9100 2500
{'bdsr': 0, 'fwq': 1, 'gj': 2, 'htj': 3, 'lgq': 4, 'lqs': 5, 'lx': 6, 'mzd': 7, 'oyx': 8, 'sgt': 9, 'smh': 10, 'wxz': 11, 'yyr': 12, 'yzq': 13, 'zmf': 14}


In [None]:
### UNSEEN CLASSES
# Classes present in the 20-class folder but not among the seen classes are held out as "unseen".
unseen_root = 'calligraphy_data/train_20'
# Keep directories only, so stray files (e.g. OS metadata) are never mistaken for a class.
all_classes = [d for d in os.listdir(unseen_root) if os.path.isdir(os.path.join(unseen_root, d))]
unseen_names = sorted(set(all_classes).difference(class_names))
unseen_name_to_ix = {cname:ci for ci,cname in enumerate(unseen_names)}
ix_to_unseen_name = {ci:cname for ci,cname in enumerate(unseen_names)}
print(unseen_names)
max_n_per_unseen = 500//downscale
unseen_IDs = []
unseen_labels = []

for i,uc in enumerate(unseen_names):
  # glob() order is arbitrary; sort so the first max_n_per_unseen files are the same on every run
  ids = sorted(glob(f'{unseen_root}/{uc}/*.jpg'))[:max_n_per_unseen]
  unseen_IDs += ids
  unseen_labels += [i]*len(ids)
len(unseen_IDs)

['csl', 'hy', 'mf', 'shz', 'wzm']


2500

In [None]:
def train_gen():
  """Yield one {'image', 'labels'} record per sampled training image."""
  for img_path, label in zip(train_IDs, train_labels):
    yield {'image': Image.open(img_path), 'labels': label}

def test_gen():
  """Yield one {'image', 'labels'} record per sampled test image."""
  for img_path, label in zip(test_IDs, test_labels):
    yield {'image': Image.open(img_path), 'labels': label}

def unseen_gen():
  """Yield one {'image', 'labels'} record per image of the held-out (unseen) classes."""
  for img_path, label in zip(unseen_IDs, unseen_labels):
    yield {'image': Image.open(img_path), 'labels': label}


# Build one Hugging Face Dataset per split from the generators above.
#https://stackoverflow.com/questions/76001128/splitting-dataset-into-train-test-and-validation-using-huggingface-datasets-fun
ds_train = Dataset.from_generator(train_gen)#.train_test_split(test_size=0.1)
ds_test = Dataset.from_generator(test_gen)
ds_unseen = Dataset.from_generator(unseen_gen)

# (rows, columns) of each split, shown as the cell output
tuple(split.shape for split in (ds_train, ds_test, ds_unseen))

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

((9100, 2), (2500, 2), (2500, 2))

CIFAR-100

In [None]:
# ds_train_cf, ds_test_cf = load_dataset('cifar100', split=['train[:50000]', 'test[:10000]'])
# ds_train_cf = ds_train_cf.rename_column("fine_label", "labels")
# ds_test_cf = ds_test_cf.rename_column("fine_label", "labels")
# ds_train_cf = ds_train_cf.rename_column("img", "image")
# ds_test_cf = ds_test_cf.rename_column("img", "image")

# id2label_cf = {id:label for id, label in enumerate(ds_train_cf.features['labels'].names)}
# label2id_cf = {label:id for id,label in id2label_cf.items()}

In [None]:
# unseen_names_cf = ['orchid', 'poppy', 'rose', 'sunflower', 'tulip']
# unseen_ids_cf = [label2id_cf[x] for x in unseen_names_cf]

# ds_unseen_cf = ds_train_cf.filter(lambda x: x['labels'] in unseen_ids_cf)
# ds_train_cf = ds_train_cf.filter(lambda x: x['labels'] not in unseen_ids_cf and x['labels'] < 20)
# ds_test_cf = ds_test_cf.filter(lambda x: x['labels'] not in unseen_ids_cf and x['labels'] < 20)

# class_names_cf = [name for name in ds_train_cf.features['labels'].names if name not in unseen_names_cf]

# old_indices = list(np.unique(ds_train_cf['labels']))
# def new_class_index(x):
#     x['labels'] = int(old_indices.index(x['labels']))
#     return x
# ds_train_cf = ds_train_cf.map(new_class_index)
# ds_test_cf = ds_test_cf.map(new_class_index)

# unseen_labels_cf = ds_unseen_cf['labels']
# test_labels_cf = ds_test_cf['labels']

Fine-tuning

In [None]:
def collate_fn(batch):
    """Collate a list of examples into one batch dict for the Trainer.

    Each example must provide a 'pixel_values' tensor and a 'labels' value; the
    tensors are stacked along a new first dimension and the labels are packed
    into a single tensor.
    """
    pixel_values = torch.stack([example['pixel_values'] for example in batch])
    labels = torch.tensor([example['labels'] for example in batch])
    return {'pixel_values': pixel_values, 'labels': labels}


# Accuracy metric object, loaded once and reused for every evaluation.
metric = load_metric("accuracy")
def compute_metrics(p):
    """Return the accuracy of an evaluation prediction.

    The predicted class is the arg-max of ``p.predictions`` along axis 1 and is
    compared against ``p.label_ids``.
    """
    predicted_classes = np.argmax(p.predictions, axis=1)
    return metric.compute(predictions=predicted_classes, references=p.label_ids)

  metric = load_metric("accuracy")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


Downloading builder script:   0%|          | 0.00/1.65k [00:00<?, ?B/s]

In [None]:
class FTModel:
    """A pretrained image classifier bundled with its data pipeline and Hugging Face Trainer.

    Construction loads the image processor and the model, attaches the
    augmentation transform to the train/test/unseen datasets defined at module
    level, and builds ``TrainingArguments`` plus a ``Trainer``.

    Parameters
    ----------
    model_name_or_path : checkpoint identifier passed to ``from_pretrained``.
    nickname : short model name; used in output paths and plot titles.
    nhs : stored as ``self.nhs``; CustomCallback analyses layers ``range(nhs)``.
    processor_class : class whose ``from_pretrained`` yields the image processor.
    basemodel : ``...ForImageClassification`` class used to load the model.
    model_type : 'CNN' or another string (e.g. 'Transformer'); selects how hidden
        states are pooled in ``compute_nmis``.
    callig : True to use the calligraphy datasets, False for the CIFAR ones
        (the ``*_cf`` globals, which must then be defined).
    n_batches : forwarded to ``TrainingArguments(num_train_epochs=...)`` -- despite
        the name it is a number of *epochs*.
    batch_size : per-device training batch size; also stored as ``self.BATCH_SIZE``.
    """
    def __init__(self, model_name_or_path, nickname, nhs, processor_class, basemodel, model_type, callig=True, n_batches=4, batch_size=72):
        self.modelname = nickname
        self.datasetname = 'Calligraphy' if callig else 'CIFAR-100'
        self.processor = processor_class.from_pretrained(model_name_or_path)
        self.datadir = 'data/' + self.modelname + self.datasetname
        self.nhs = nhs
        self.model_type = model_type
        self.callig = callig
        self.calligstr = 'callig' if callig else 'cifar'
        self.plot = True#callig
        self.calc_seen_nmi = True#callig
        # makedirs also creates the parent 'data' folder if it is missing and is a
        # no-op when the directory already exists.
        os.makedirs(self.datadir, exist_ok=True)

        # The processor reports its target size either as {'height', 'width'} or as {'shortest_edge'}.
        size_key = 'height' if 'height' in self.processor.size else 'shortest_edge'
        size = self.processor.size[size_key]

        image_mean = self.processor.image_mean
        image_std = self.processor.image_std

        my_transforms = transforms.Compose([
            transforms.Resize(size), #resize first to avoid jaggedy edges after rotation
            transforms.Pad(20,255), #pad before rotating to avoid dark borders
            transforms.RandomRotation(degrees=10),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize(mean=image_mean, std=image_std),
        ])

        def transform(example_batch):
            """Apply ``my_transforms`` to every image of a batch and pass the labels through."""
            inputs = {}#processor([x for x in example_batch['image']], return_tensors='pt')
            inputs['pixel_values'] = [my_transforms(image) for image in example_batch['image']]

            # Don't forget to include the labels!
            inputs['labels'] = example_batch['labels']
            return inputs

        # NOTE(review): the same transform -- including RandomRotation -- is attached to the
        # test and unseen datasets, so evaluation images are randomly rotated too.
        # Confirm this is intended.
        if callig:
            self.prepared_ds_train = ds_train.with_transform(transform)
            self.prepared_ds_test = ds_test.with_transform(transform)
            self.prepared_ds_unseen = ds_unseen.with_transform(transform)
            self.data_class_names = class_names
        else:
            # cifar data
            self.prepared_ds_train = ds_train_cf.with_transform(transform)
            self.prepared_ds_test = ds_test_cf.with_transform(transform)
            self.prepared_ds_unseen = ds_unseen_cf.with_transform(transform)
            self.data_class_names = class_names_cf

        # ignore_mismatched_sizes lets the pretrained classification head be replaced by one
        # sized for our number of classes.
        self.model = basemodel.from_pretrained(
            model_name_or_path,
            num_labels=len(self.data_class_names),
            id2label={str(i): c for i, c in enumerate(self.data_class_names)},
            label2id={c: str(i) for i, c in enumerate(self.data_class_names)},
            ignore_mismatched_sizes = True,
        )
        self.model = self.model.to(device)

        self.BATCH_SIZE = batch_size
        self.training_args = TrainingArguments(
            output_dir=f"./Trainers/{self.modelname}_{self.calligstr}Trainer",
            per_device_train_batch_size=self.BATCH_SIZE,
            evaluation_strategy="steps",
            num_train_epochs=n_batches,  # n_batches is a number of epochs (see class docstring)
            fp16=torch.cuda.is_available(),
            save_steps=100,
            eval_steps=100,
            logging_steps=10,
            learning_rate=2e-4,
            save_total_limit=2,
            remove_unused_columns=False,  # keep the 'image' column that `transform` reads
            push_to_hub=False,
            report_to='tensorboard',
            load_best_model_at_end=True,
        )

        self.trainer = Trainer(
            model=self.model,
            args=self.training_args,
            data_collator=collate_fn,
            compute_metrics=compute_metrics,
            train_dataset=self.prepared_ds_train,
            eval_dataset=self.prepared_ds_test,
            tokenizer=self.processor,
        )

    def get_embed_pkl_fname(self, step, unseen_str):
        """Path of the pickle holding the metrics for ``step`` (empty step string when ``step`` is None)."""
        return f'{self.datadir}/saved_embeds_step{step if step is not None else ""}_{unseen_str}.pkl'

    def get_embed_plot_fname(self, step, unseen_str):
        """Path of the PDF scatter plot for ``step``; unlike the pickle name, a None step is rendered as 'None'."""
        return f'{self.datadir}/{step}_{unseen_str}.pdf'

    def get_dir(self):
        """Return this model's output directory, with a trailing slash."""
        return f'{self.datadir}/'

def newResNet(callig=True, **kwargs):
    """Create an FTModel for the 'microsoft/resnet-50' checkpoint.

    ``callig`` selects the dataset (True: calligraphy, False: CIFAR); any extra
    keyword arguments (e.g. ``n_batches``, ``batch_size``) go straight to FTModel.
    """
    settings = dict(
        model_name_or_path='microsoft/resnet-50',
        nickname='ResNet',
        nhs=5,
        processor_class=AutoImageProcessor,
        basemodel=ResNetForImageClassification,
        model_type='CNN',
        callig=callig,
    )
    return FTModel(**settings, **kwargs)

def newViT(callig=True, **kwargs):
    """Create an FTModel for the 'google/vit-base-patch16-224-in21k' checkpoint.

    ``callig`` selects the dataset (True: calligraphy, False: CIFAR); any extra
    keyword arguments (e.g. ``n_batches``, ``batch_size``) go straight to FTModel.
    """
    settings = dict(
        model_name_or_path='google/vit-base-patch16-224-in21k',
        nickname='ViT',
        nhs=13,
        processor_class=ViTImageProcessor,
        basemodel=ViTForImageClassification,
        model_type='Transformer',
        callig=callig,
    )
    return FTModel(**settings, **kwargs)

def newSwin(callig=True, **kwargs):
    """Create an FTModel for the 'microsoft/swin-tiny-patch4-window7-224' checkpoint.

    ``callig`` selects the dataset (True: calligraphy, False: CIFAR); any extra
    keyword arguments (e.g. ``n_batches``, ``batch_size``) go straight to FTModel.
    """
    settings = dict(
        model_name_or_path='microsoft/swin-tiny-patch4-window7-224',
        nickname='Swin',
        nhs=5,
        processor_class=AutoImageProcessor,
        basemodel=SwinForImageClassification,
        model_type='Transformer',
        callig=callig,
    )
    return FTModel(**settings, **kwargs)

def newPViT(callig=True, **kwargs):
    """Create an FTModel for the 'Zetatech/pvt-tiny-224' checkpoint.

    ``callig`` selects the dataset (True: calligraphy, False: CIFAR). Extra keyword
    arguments (e.g. ``n_batches``, ``batch_size``) are forwarded to FTModel; they
    used to be accepted but silently discarded.
    """
    return FTModel(
        model_name_or_path = 'Zetatech/pvt-tiny-224',
        nickname = 'PViT',
        nhs = 9,
        processor_class = AutoImageProcessor,
        basemodel = PvtForImageClassification,
        model_type = 'Transformer',
        callig = callig,
        **kwargs
    )

def newCvT(callig=True, **kwargs):
    """Create an FTModel for the 'microsoft/cvt-21' checkpoint.

    ``callig`` selects the dataset (True: calligraphy, False: CIFAR). Extra keyword
    arguments (e.g. ``n_batches``, ``batch_size``) are forwarded to FTModel, matching
    the ResNet/ViT/Swin factories.
    """
    return FTModel(
        model_name_or_path = 'microsoft/cvt-21',
        nickname = 'CvT',
        nhs = 3,
        processor_class = AutoImageProcessor,
        basemodel = CvtForImageClassification,
        model_type = 'CNN',
        callig = callig,
        **kwargs
    )

def newPoolFormer(callig=True, **kwargs):
    """Create an FTModel for the 'sail/poolformer_s12' checkpoint.

    ``callig`` selects the dataset (True: calligraphy, False: CIFAR). Extra keyword
    arguments (e.g. ``n_batches``, ``batch_size``) are forwarded to FTModel, matching
    the ResNet/ViT/Swin factories.
    """
    return FTModel(
        model_name_or_path = 'sail/poolformer_s12',
        nickname = 'PoolFormer',
        nhs = 4,
        processor_class = AutoImageProcessor,
        basemodel = PoolFormerForImageClassification,
        model_type = 'CNN', #?
        callig = callig,
        **kwargs
    )

def newConvNeXtV2(callig=True, **kwargs):
    """Create an FTModel for the 'facebook/convnextv2-tiny-1k-224' checkpoint.

    ``callig`` selects the dataset (True: calligraphy, False: CIFAR). Extra keyword
    arguments (e.g. ``n_batches``, ``batch_size``) are forwarded to FTModel, matching
    the ResNet/ViT/Swin factories.
    """
    return FTModel(
        model_name_or_path = "facebook/convnextv2-tiny-1k-224",
        nickname = 'ConvNeXtV2',
        nhs = 5,
        processor_class = AutoImageProcessor,
        basemodel = ConvNextV2ForImageClassification,
        model_type = 'CNN',
        callig = callig,
        **kwargs
    )

In [None]:
# ftmodels = []
# # these run calligraphy + seen
# ResNet = newResNet(n_batches=5)
# ftmodels.append(ResNet)

# ViT = newViT()
# ftmodels.append(ViT)

# Swin = newSwin()
# ftmodels.append(Swin)

# PViT = newPViT()
# ftmodels.append(PViT)

# CvT = newCvT()
# ftmodels.append(CvT)

# PoolFormer = newPoolFormer()
# ftmodels.append(PoolFormer)

# ConvNeXtV2 = newConvNeXtV2()
# ftmodels.append(ConvNeXtV2)

############################################

# ftmodels_cf = []
# # these run CIFAR-100 + seen
# ResNetCF = newResNet(False) #True: calligraphy; False: CIFAR
# ftmodels_cf.append(ResNetCF)

# ViTCF = newViT(False)
# ftmodels_cf.append(ViTCF)

# SwinCF = newSwin(False)
# ftmodels_cf.append(SwinCF)

# PViTCF = newPViT(False)
# ftmodels_cf.append(PViTCF)

# CvTCF = newCvT(False)
# ftmodels_cf.append(CvTCF)

# PoolFormerCF = newPoolFormer(False)
# ftmodels_cf.append(PoolFormerCF)

# ConvNeXtV2CF = newConvNeXtV2(False)
# ftmodels_cf.append(ConvNeXtV2CF)

In [None]:
def initial_eval(ftmodel):
    """Evaluate ``ftmodel`` on its prepared test set, then log and save the metrics under "eval"."""
    trainer = ftmodel.trainer
    eval_metrics = trainer.evaluate(ftmodel.prepared_ds_test)
    for report in (trainer.log_metrics, trainer.save_metrics):
        report("eval", eval_metrics)

#for ftmodel in ftmodels:
#    print(ftmodel.nickname)
#    initial_eval(ftmodel)

In [None]:
def retrieveAllSparseNbrs(A):
    """Yield, row by row, the stored column indices of a CSR-style sparse matrix.

    ``A`` must expose ``shape``, ``indptr`` and ``indices``; row ``i`` owns the
    slice ``indices[indptr[i]:indptr[i+1]]``.
    """
    assert A.shape[0] == len(A.indptr)-1
    for row_start, row_end in zip(A.indptr[:-1], A.indptr[1:]):
        yield A.indices[row_start:row_end]

def compute_nmis(ftmodel, which_IDs, which_ds, which_classes, which_labels, lis, modelname, batch_size,
                 embed_types = ('e0','mean'), step=None, plot=False, unseen=True):
    """Embed a dataset with ``ftmodel`` and score how well each layer separates the classes.

    Every hidden state returned by the model is reduced to one vector per image
    ('e0': flattened feature map for CNNs / first token otherwise; 'mean': spatial
    mean for CNNs / token mean otherwise). For each layer in ``lis`` and each
    embedding type, three scores are computed:

    * 'NMI'    -- normalized mutual information between k-means clusters and labels,
    * 'NNk1'   -- fraction of samples whose nearest neighbour has the same label,
    * 'NNk500' -- mean fraction of same-label samples among the ``max_n_per_unseen``
                  nearest neighbours (module-level constant).

    The scores are pickled to ``ftmodel.get_embed_pkl_fname(step, ...)`` and, when
    ``plot`` is true, a 2-D PCA scatter per layer is saved to
    ``ftmodel.get_embed_plot_fname(step, ...)``.

    Parameters
    ----------
    ftmodel : object exposing ``model``, ``model_type`` and the two filename helpers.
    which_IDs : only its length is used (number of samples to embed).
    which_ds : dataset whose slices provide a list of 'pixel_values' tensors.
    which_classes : class names; only their count is used (number of k-means clusters).
    which_labels : integer label per sample, aligned with ``which_ds``.
    lis : indices of the hidden states to analyse.
    modelname : unused; kept for call compatibility.
    batch_size : number of images per forward pass.
    embed_types : any of 'e0' and 'mean'.
    step : training step, used in file names and log lines.
    plot : whether to save the PCA scatter plots.
    unseen : selects the 'unseen' / 'seen' tag used in file names and titles.

    Returns
    -------
    dict : ``{embed_type: {layer: {'NMI': ..., 'NNk1': ..., 'NNk500': ...}}}``
        (the same object that is pickled).
    """
    n = len(which_IDs)
    unseen_str = "unseen" if unseen else "seen"
    n_batches = int(np.ceil(n/batch_size))
    is_cnn = ftmodel.model_type == 'CNN'

    # chunks[embed_type][layer] -> list of per-batch arrays; concatenated once after the
    # loop instead of re-copying the growing array on every batch.
    chunks = {et:{} for et in embed_types}

    model = ftmodel.model
    # Extract embeddings in eval mode (no dropout, no batch-norm statistic updates) and
    # restore whatever mode the model was in afterwards.
    was_training = model.training
    model.eval()
    try:
        for bi in range(n_batches):
            print(bi,end=' ')

            inputs = torch.stack(which_ds[bi*batch_size:(bi+1)*batch_size]['pixel_values'])
            inputs = inputs.to(device)
            with torch.no_grad():
                outputs = model(inputs, output_hidden_states=True)

            for li, hidden_state in enumerate(outputs['hidden_states']):
                hidden = hidden_state.detach().cpu().numpy()
                if 'e0' in embed_types:
                    if is_cnn:
                        e0 = hidden.reshape(hidden.shape[0],-1)
                    else:
                        # copy so the full hidden-state array is not kept alive by a view
                        e0 = hidden[:,0].copy()
                    chunks['e0'].setdefault(li, []).append(e0)
                if 'mean' in embed_types:
                    pooled = hidden.mean((2,3)) if is_cnn else hidden.mean(1)
                    chunks['mean'].setdefault(li, []).append(pooled)
    finally:
        model.train(was_training)
    print()

    embeds = {et: {li: np.concatenate(parts) for li, parts in per_layer.items()}
              for et, per_layer in chunks.items()}

    pca = PCA(2)

    knn = NearestNeighbors(
        algorithm='auto',
        n_jobs=-1,
        n_neighbors=max_n_per_unseen,
    )

    km = KMeans(n_clusters=len(which_classes), random_state=42, n_init="auto")

    saved_nmis = {et:{li:{} for li in lis} for et in embed_types}
    ylabels = np.asarray(which_labels)

    ncols = len(embed_types)
    if plot:
        # squeeze=False keeps `axes` two-dimensional even for a single row or column
        fig, axes = plt.subplots(len(lis),ncols,figsize=(ncols*4,len(lis)*4),squeeze=False)
    for i,li in enumerate(lis):
        for j,et in enumerate(embed_types):
            x = embeds[et][li]
            scores = saved_nmis[et][li]

            ### COMPUTE K-MEANS
            km.fit(x)
            scores['NMI'] = normalized_mutual_info_score(which_labels,km.labels_)

            ### COMPUTE NNs
            knn.fit(x)
            # Without a query argument each training point is looked up against the others
            # (it is not its own neighbour); columns are ordered nearest first.
            NN_idxs = knn.kneighbors(return_distance=False)
            NN_labels = ylabels[NN_idxs]
            scores['NNk1'] = (ylabels == NN_labels[:,0]).sum()/len(ylabels)
            scores['NNk500'] = np.mean((ylabels[:,None] == NN_labels).sum(1)/max_n_per_unseen)

            titl = f"L{li}-{et}: " + f"NMI={scores['NMI']:.2f}, NNk1={scores['NNk1']:.2f} " + \
                    f"NNk500={scores['NNk500']:.2f}"
            print(f'{unseen_str.upper()} {et} step {step}: {titl}')
            if plot:
                y = pca.fit_transform(x) #run pca for plotting
                with plt.style.context('seaborn-v0_8-paper'):
                    ax = axes[i,j]
                    ax.scatter(*y.T, c=which_labels, s=8, cmap='Paired', alpha=.75, label=unseen_str)

                    ax.set_title(titl,size=7)
                    ax.legend()

    if plot:
        fig.suptitle(unseen_str)
        fig.savefig(ftmodel.get_embed_plot_fname(step, unseen_str), transparent=True, bbox_inches='tight')
        # close explicitly: this runs at every evaluation and figures would otherwise pile up
        plt.close(fig)

    pklfname = ftmodel.get_embed_pkl_fname(step, unseen_str)
    with open(pklfname,'wb') as pkl_file:
        pickle.dump(saved_nmis,pkl_file)

    return saved_nmis

class CustomCallback(TrainerCallback):
    """Trainer callback that runs ``compute_nmis`` on the unseen (and optionally seen) data.

    The analysis runs after every evaluation and once at the very first epoch
    start (global step 0).
    """
    def __init__(self, ftmodel) -> None:
        super().__init__()
        self.ftmodel = ftmodel
        self._trainer = ftmodel.trainer

        self.batch_size = ftmodel.BATCH_SIZE
        self.lis = range(ftmodel.nhs)
        self.modelname = ftmodel.modelname

    def seen_and_unseen_nmis(self, types, global_step):
        """Score the unseen set and, if ``calc_seen_nmi`` is set, the seen test set."""
        ftm = self.ftmodel
        # Pick the module-level data that matches the dataset this model was built for.
        if ftm.callig:
            unseen_ids, seen_ids = unseen_IDs, test_IDs
            unseen_classes = unseen_names
            unseen_labs, test_labs = unseen_labels, test_labels
        else:
            unseen_ids, seen_ids = ds_unseen_cf, ds_test_cf
            unseen_classes = unseen_names_cf
            unseen_labs, test_labs = unseen_labels_cf, test_labels_cf

        compute_nmis(ftm, unseen_ids, ftm.prepared_ds_unseen, unseen_classes, unseen_labs,
                     self.lis, self.modelname, self.batch_size,
                     embed_types=types, step=global_step, plot=ftm.plot)

        # Original author's note: this does not work for CIFAR, cause unknown.
        if ftm.calc_seen_nmi:
            compute_nmis(ftm, seen_ids, ftm.prepared_ds_test, ftm.data_class_names, test_labs,
                         self.lis, self.modelname, self.batch_size,
                         embed_types=types, step=global_step, plot=ftm.plot, unseen=False)

    def on_evaluate(self, args, state, control, **kwargs):
        """After each evaluation, compute the metrics at the trainer's current global step."""
        current_step = self._trainer.state.global_step
        # CNNs are summarised by their spatial mean, other models by their first token
        pooling = ['mean'] if self.ftmodel.model_type == 'CNN' else ['e0']
        self.seen_and_unseen_nmis(pooling, current_step)

    def on_epoch_begin(self, args, state, control, **kwargs):
        """Compute the metrics once before any training step has been taken (step 0 only)."""
        current_step = self._trainer.state.global_step
        if current_step > 0:
            return
        pooling = ['mean'] if self.ftmodel.model_type == 'CNN' else ['e0']
        self.seen_and_unseen_nmis(pooling, current_step)

In [None]:
import transformers
def train_and_save(ftmodel):
    """Train ``ftmodel`` with the embedding-analysis callback attached, then persist everything.

    Saves the model, logs and saves the training metrics, saves the trainer
    state, and returns the object produced by ``trainer.train()``.
    """
    trainer = ftmodel.trainer
    trainer.add_callback(CustomCallback(ftmodel))

    train_results = trainer.train()
    trainer.save_model()

    train_metrics = train_results.metrics
    trainer.log_metrics("train", train_metrics)
    trainer.save_metrics("train", train_metrics)
    trainer.save_state()
    return train_results

def plot_NMI_across_layers(ftmodel, steps=(0,100,200,300,400,500)):
    """Plot each saved metric against layer index, one curve per training step.

    Reads the 'unseen' pickles written by ``compute_nmis`` for every step in
    ``steps`` and produces one figure per metric ('NMI', 'NNk1', 'NNk500'),
    saved as ``g_across_layers_<metric>.png`` in the model's output directory.

    Parameters
    ----------
    ftmodel : object exposing ``get_embed_pkl_fname``, ``get_dir``, ``modelname``
        and ``datasetname``.
    steps : training steps to plot; a pickle must exist for each of them
        (a missing file raises FileNotFoundError).
    """
    # Load every step once; the files are written by this notebook's own compute_nmis.
    embs = {}
    for step in steps:
        with open(ftmodel.get_embed_pkl_fname(step, 'unseen'),'rb') as f:
            embs[step] = pickle.load(f)

    for metric_name in ['NMI', 'NNk1', 'NNk500']:
        # A fresh figure per metric, so curves of different metrics never share axes.
        fig, ax = plt.subplots()
        for step in steps:
            methd = list(embs[step].keys())[0]  # first (usually only) embedding type stored
            lis = sorted(embs[step][methd].keys())
            ys = [embs[step][methd][li][metric_name] for li in lis]
            ax.plot(lis,ys,label=f'{step=}')
        ax.set_ylabel(metric_name)  # label with the metric actually plotted
        ax.set_xlabel('Layer')
        ax.set_title(f'{ftmodel.modelname} on {ftmodel.datasetname}')
        ax.legend()
        fig.savefig(ftmodel.get_dir() + f'g_across_layers_{metric_name}.png', transparent=True, bbox_inches='tight')
        plt.show()
        plt.close(fig)

Train

In [None]:
# [ResNet, ViT, Swin, PViT, CvT, PoolFormer, ConvNeXtV2]

# Fine-tune ResNet on the calligraphy data and plot its per-layer metrics.
# n_batches=5 is forwarded to FTModel, which uses it as the number of training epochs.
transformers.set_seed(42)  # seed before building the model so the run is repeatable
ResNet = newResNet(n_batches=5)
train_and_save(ResNet)
plot_NMI_across_layers(ResNet)

In [None]:
# Fine-tune ViT on the calligraphy data with FTModel's default settings, then plot its per-layer metrics.
transformers.set_seed(42)  # seed before building the model so the run is repeatable
ViT = newViT()
train_and_save(ViT)
plot_NMI_across_layers(ViT)

In [None]:
# Fine-tune Swin on the calligraphy data with FTModel's default settings, then plot its per-layer metrics.
transformers.set_seed(42)  # seed before building the model so the run is repeatable
Swin = newSwin()
train_and_save(Swin)
plot_NMI_across_layers(Swin)

In [None]:
# Fine-tune PViT on the calligraphy data with FTModel's default settings, then plot its per-layer metrics.
transformers.set_seed(42)  # seed before building the model so the run is repeatable
PViT = newPViT()
train_and_save(PViT)
plot_NMI_across_layers(PViT)

In [None]:
# Fine-tune CvT on the calligraphy data with FTModel's default settings, then plot its per-layer metrics.
transformers.set_seed(42)  # seed before building the model so the run is repeatable
CvT = newCvT()
train_and_save(CvT)
plot_NMI_across_layers(CvT)

In [None]:
# Fine-tune PoolFormer on the calligraphy data with FTModel's default settings, then plot its per-layer metrics.
transformers.set_seed(42)  # seed before building the model so the run is repeatable
PoolFormer = newPoolFormer()
train_and_save(PoolFormer)
plot_NMI_across_layers(PoolFormer)

In [None]:
# Fine-tune ConvNeXtV2 on the calligraphy data with FTModel's default settings, then plot its per-layer metrics.
transformers.set_seed(42)  # seed before building the model so the run is repeatable
ConvNeXtV2 = newConvNeXtV2()
train_and_save(ConvNeXtV2)
plot_NMI_across_layers(ConvNeXtV2)

Save results

In [None]:
# to save training accuracies:
# 1. put screenshot in downloaded file
# 2. add list of accuracies to following pickle file
#
# NOTE: this cell is a manual template and is not runnable as written. MODEL, ACCSTEP0
# and ACCSTEP100 are placeholders that are not defined in the cells shown here: replace
# MODEL with a trained FTModel (e.g. ResNet) and the list with the accuracy recorded at
# each evaluation step before running it.
accs_fname = MODEL.get_dir() + f'training_accs.pkl'  # saved next to the model's other outputs
accs = [ACCSTEP0, ACCSTEP100, ...]
with open(accs_fname,'wb') as f:
  pickle.dump(accs,f)

In [None]:
!zip -q -r alldata.zip data/

In [None]:
files.download('alldata.zip')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>