In [52]:
import os
import re

import anndata
import numpy as np
import pandas as pd
import seaborn as sns
import scanpy as sc
from umap import UMAP
from sklearn.preprocessing import StandardScaler

In [53]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

In [54]:
%load_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [55]:
%autoreload 2

In [56]:
from violet.utils.model import load_pretrained_model
from violet.utils.dataloaders import listfiles, image_regression_dataloaders
from violet.models.st import STRegressor, STLearner
from violet.utils.st import load_st_learner, run_st_learner, load_imagenet_st_learner

In [57]:
torch.cuda.set_device(3)
torch.cuda.current_device()

3

In [58]:
img_dir = '/data/violet/st/pdac_ffpe_raw'
weights = '/home/estorrs/violet/sandbox/dino_runs/pdac_ffpe_tcia_raw_augmented_20samples/checkpoint0060.pth'
run_dir = '/home/estorrs/violet/sandbox/runs/pdac_ffpe_tcia_raw_augmented_20samples'

In [59]:
def get_target_df(folder):
    target_df = None
    fps = listfiles(folder, regex=r'_sp.h5ad$')
    for fp in fps:
        a = sc.read_h5ad(fp)
        sample = fp.split('/')[-1].split('.')[0].split('_sp')[0]
        df = a.obsm['tangram_ct_pred']
#         df = pd.DataFrame(data=df.values / np.max(df.values, axis=0),
#                          columns=df.columns, index=df.index)
#         X = StandardScaler().fit_transform(df.values)
#         df = pd.DataFrame(data=X + np.abs(np.min(X)),
#                          columns=df.columns, index=df.index)
        X = np.log1p(df.values)
        df = pd.DataFrame(data=X / np.max(X, axis=0),
                          columns=df.columns, index=df.index)
        df.index = [f'{sample}_{x}' for x in df.index]
        
        if target_df is None:
            target_df = df
        else:
            target_df = pd.concat((target_df, df))
    return target_df

In [60]:
fmap = pd.read_csv('/home/estorrs/spatial-analysis/data/sample_map.txt', sep='\t', index_col=0)
fmap

Unnamed: 0_level_0,spaceranger_output,highres_image,disease,tissue_type
sample_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
HT206B1_H8_U2,/data/spatial_transcriptomics/spaceranger_outp...,/data/spatial_transcriptomics/highres_images/b...,brca,oct
HT206B1_H8_U3,/data/spatial_transcriptomics/spaceranger_outp...,/data/spatial_transcriptomics/highres_images/b...,brca,oct
HT206B1_H8_U4,/data/spatial_transcriptomics/spaceranger_outp...,/data/spatial_transcriptomics/highres_images/b...,brca,oct
HT206B1_H8_U5,/data/spatial_transcriptomics/spaceranger_outp...,/data/spatial_transcriptomics/highres_images/b...,brca,oct
HT206B1_H8_Bn,/data/spatial_transcriptomics/spaceranger_outp...,,brca,oct
...,...,...,...,...
NMK_20201012,/data/spatial_transcriptomics/spaceranger_outp...,/data/spatial_transcriptomics/highres_images/m...,mouse_kidney,oct
AKICL_14w,/data/spatial_transcriptomics/spaceranger_outp...,/data/spatial_transcriptomics/highres_images/m...,mouse_kidney,oct
AKI_M_14w,/data/spatial_transcriptomics/spaceranger_outp...,/data/spatial_transcriptomics/highres_images/m...,mouse_kidney,oct
SP1896H1_U1,/data/spatial_transcriptomics/spaceranger_outp...,/data/spatial_transcriptomics/highres_images/i...,normal_stomach,oct


In [61]:
adata_map = {}
for s, fp, d, t in zip(fmap.index, fmap['spaceranger_output'], fmap['disease'], fmap['tissue_type']):
    if 'pdac' == d and t == 'ffpe':
        adata_map[s] = fp
adata_map

{'HT264P1_S1H2Fs1_U1': '/data/spatial_transcriptomics/spaceranger_outputs/pancreatic/HT264P1-S1H2Fs1U1Bp1',
 'HT270P1_S1H1Fs5U1': '/data/spatial_transcriptomics/spaceranger_outputs/pancreatic/HT270P1-S1H1Fs5U1Bp1'}

In [62]:
# small = ['HT206B1_H8_U2', 'HT206B1_H8_U3']
# adata_map = {k:v for k, v in adata_map.items() if k in small}

In [63]:
val_samples = ['HT264P1_S1H2Fs1_U1']

In [64]:
target_df = get_target_df('/home/estorrs/tangram_annotation/results/pdac_ffpe/')
target_df

Unnamed: 0,Monocyte,Fibroblast,CD8 T cell,Treg,Epithelial,Plasma,NK,Dendritic,Endothelial,Malignant,Tuft,Acinar,B cell,Erythrocyte,Islet,CD4 T cell,Mast
HT264P1_S1H2Fs1_U1_AAACAAGTATCTCCCA-1,0.197496,0.335588,0.190963,0.492354,0.000727,0.037626,0.209403,0.057444,0.195201,0.097795,0.000209,0.123608,0.139343,0.001678,0.244454,0.310614,0.013886
HT264P1_S1H2Fs1_U1_AAACAGAGCGACTCCT-1,0.231423,0.004139,0.058833,0.000594,0.000181,0.065302,0.209016,0.000323,0.296314,0.001164,0.000106,0.100393,0.078956,0.012263,0.084652,0.133947,0.000989
HT264P1_S1H2Fs1_U1_AAACAGTGTTCCTGGG-1,0.001427,0.250531,0.056429,0.000464,0.000169,0.001107,0.000569,0.000279,0.374196,0.001317,0.000234,0.301653,0.001488,0.168700,0.074898,0.001489,0.242065
HT264P1_S1H2Fs1_U1_AAACATTTCCCGGATT-1,0.295681,0.352007,0.322431,0.220487,0.000312,0.002148,0.001117,0.000569,0.297659,0.322470,0.000198,0.078337,0.139421,0.002820,0.108844,0.134540,0.427028
HT264P1_S1H2Fs1_U1_AAACCCGAACGAAATC-1,0.177384,0.312500,0.309823,0.000926,0.204923,0.002046,0.209044,0.000533,0.258758,0.362486,0.000260,0.724483,0.397104,0.029003,0.202363,0.305218,0.033817
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
HT270P1_S1H1Fs5U1_TTGTTGTGTGTCAAGA-1,0.180117,0.455171,0.338833,0.002550,0.000631,0.005643,0.001462,0.182932,0.391215,0.001921,0.000247,0.001230,0.260692,0.029664,0.000589,0.306876,0.334085
HT270P1_S1H1Fs5U1_TTGTTTCACATCCAGG-1,0.348536,0.447388,0.290224,0.275600,0.000650,0.006281,0.009072,0.115008,0.444863,0.384960,0.254777,0.057527,0.352485,0.004224,0.000266,0.475488,0.153015
HT270P1_S1H1Fs5U1_TTGTTTCATTAGTCTA-1,0.282419,0.449857,0.331486,0.044486,0.000750,0.007361,0.239195,0.001466,0.433042,0.428319,0.000311,0.039249,0.231990,0.013249,0.000308,0.571650,0.224206
HT270P1_S1H1Fs5U1_TTGTTTCCATACAACT-1,0.149052,0.296272,0.327674,0.002025,0.000381,0.198074,0.001188,0.002369,0.275183,0.001615,0.000176,0.047301,0.124609,0.015668,0.000211,0.269908,0.146062


In [65]:
target_df = target_df[['CD8 T cell', 'Malignant', 'Fibroblast']]

In [66]:
learner = load_st_learner(img_dir, weights, adata_map, run_dir, 
                          model_name='xcit_small', patch_size=8,
                          val_samples=val_samples, target_df=target_df,
                          frozen_lr=1e-4, unfrozen_lr=1e-5, batch_size=16)



Take key teacher in provided checkpoint dict
Pretrained weights found at /home/estorrs/violet/sandbox/dino_runs/pdac_ffpe_tcia_raw_augmented_20samples/checkpoint0060.pth and loaded with msg: _IncompatibleKeys(missing_keys=[], unexpected_keys=['head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'])
ST Learner summary:
{'dataset': {'image_directory': '/data/violet/st/pdac_ffpe_raw',
             'min_counts': 2500,
             'resolution': 55.0,
             'targets': ['CD8 T cell', 'Malignant', 'Fibroblast'],
             'train_dataset': {'num_spots': 3940,
                               'samples': ['HT270P1_S1H1Fs5U1']},
             'val_dataset': {'num_spots': 3234,
                             'samples': ['HT264P1_S1H2Fs1_U1']}},
 'head': {'type': 'linear'},
 'hyperparams': {'batch_size': 16,
                 'frozen_lr': 0.0001,
                 'gpu': True,


In [67]:
# import torch.distributed as dist
# dist.init_process_group('gloo', init_method='file:///tmp/somefile', rank=0, world_size=1)

In [68]:
run_st_learner(learner, 5, 20, save_every=2)

Training frozen vit for 5 epochs
epoch: 0, train loss: 0.9307611584663391, val loss: 0.7378860116004944, time: 103.3515293598175
epoch: 1, train loss: 0.6188583374023438, val loss: 0.4977882504463196, time: 125.95883345603943
epoch: 2, train loss: 0.4297819435596466, val loss: 0.32395240664482117, time: 126.15314936637878
epoch: 3, train loss: 0.3246530592441559, val loss: 0.2671273946762085, time: 126.18200874328613
epoch: 4, train loss: 0.2602197229862213, val loss: 0.22414107620716095, time: 106.65325784683228
saving checkpoint at /home/estorrs/violet/sandbox/runs/pdac_ffpe_tcia_raw_augmented_20samples/checkpoints/1.finetune_frozen_vit_last_frozen.pth
Saved checkpoint at /home/estorrs/violet/sandbox/runs/pdac_ffpe_tcia_raw_augmented_20samples/checkpoints/1.finetune_frozen_vit_last_frozen.pth
Unfreezing weights
Training unfrozen vit for 20 epochs
epoch: 0, train loss: 0.1599694937467575, val loss: 0.12102872133255005, time: 182.74432849884033
saving checkpoint at /home/estorrs/violet

In [None]:
0.0058

In [18]:
learner.summary

{'run_directory': '/home/estorrs/violet/sandbox/runs/he_ffpe_pda_xcit_p8_normalized',
 'dataset': {'image_directory': '/home/estorrs/violet/data/st/pdac_ffpe_normalized',
  'targets': ['BGN',
   'CD14',
   'CD3E',
   'CD3G',
   'CD4',
   'CD68',
   'CD8A',
   'CDH1',
   'CDH5',
   'CPA3',
   'CTLA4',
   'EPCAM',
   'FAP',
   'FCER1A',
   'FOXP3',
   'GNLY',
   'HAVCR2',
   'HBA1',
   'IL7R',
   'KIT',
   'KRT18',
   'LAG3',
   'LYZ',
   'MKI67',
   'MS4A1',
   'MUC5AC',
   'NKG7',
   'PDCD1',
   'PECAM1',
   'PRSS1',
   'PTPRC',
   'RGS5',
   'SDC1',
   'SPARC',
   'TIGIT',
   'TOP2A'],
  'resolution': 55.0,
  'min_counts': 2500,
  'train_dataset': {'samples': ['HT270P1_S1H1Fs5U1'], 'num_spots': 3907},
  'val_dataset': {'samples': ['HT264P1_S1H2Fs1_U1'], 'num_spots': 780}},
 'vit': {'pretrained_weights': '/home/estorrs/violet/sandbox/dino_runs/he_ffpe_pda_xcit_p8_normalized/checkpoint.pth',
  'patch_size': 8,
  'img_size': (224, 224),
  'total_patches': 784,
  'embed_dim': 384,
  'mode

###### with alternate model

In [17]:
img_dir = '/home/estorrs/spatial-analysis/data/breast/model_inputs_06092021/he_imgs_v2'
weights = '/home/estorrs/dino/outputs/test_run_5_brca_good_only/checkpoint0480.pth'
run_dir = '/home/estorrs/violet/sandbox/runs/test_run_renet50'

In [34]:
learner = load_imagenet_st_learner(img_dir, weights, adata_map, run_dir, 
    val_samples=val_samples, targets=markers, model_name='resnet50')

Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.


ST Learner summary:
{'convnet': {'embed_dim': 1000,
             'imagenet_pretrained': True,
             'model_name': 'resnet50'},
 'dataset': {'image_directory': '/home/estorrs/spatial-analysis/data/breast/model_inputs_06092021/he_imgs_v2',
             'min_counts': 2500,
             'resolution': 55.0,
             'targets': ['ESR1',
                         'PGR',
                         'ERBB2',
                         'MKI67',
                         'TOP2A',
                         'CD3G',
                         'CD4',
                         'CD8A',
                         'KIT',
                         'EPCAM',
                         'CDH1',
                         'BGN',
                         'FAP',
                         'SPARC',
                         'ITGAX',
                         'LYZ',
                         'CD68',
                         'CD14',
                         'SDC1',
                         'PECAM1',
                         'I

In [35]:
run_st_learner(learner, 2, 3)

Training frozen vit for 2 epochs
epoch: 0, train loss: 1.567744493484497, val loss: 1.6477304697036743
epoch: 1, train loss: 1.393233060836792, val loss: 1.477431058883667
Saved checkpoint at /home/estorrs/violet/sandbox/runs/test_run_renet50/checkpoints/1.finetune_frozen_vit.pth
Unfreezing weights
Training unfrozen vit for 3 epochs
epoch: 0, train loss: 0.7829263806343079, val loss: 0.5886406302452087
epoch: 1, train loss: 0.4148525297641754, val loss: 0.4237063527107239
epoch: 2, train loss: 0.30059632658958435, val loss: 0.33077406883239746
Saved final checkpoint at /home/estorrs/violet/sandbox/runs/test_run_renet50/checkpoints/final.pth
Saved summary at /home/estorrs/violet/sandbox/runs/test_run_renet50/summary.json


In [8]:
sum(p.numel() for p in m.parameters())

25557032

###### sandbox

In [52]:
def process_adata(sid, fp, n_top=200, count_filter=2500):
    a = sc.read_visium(fp)
    a.var_names_make_unique()
    a.var["mt"] = a.var_names.str.startswith("MT-")
    sc.pp.calculate_qc_metrics(a, qc_vars=["mt"], inplace=True)
    sc.pp.filter_cells(a, min_counts=count_filter)
    a.obs.index = [f'{sid}_{x}' for x in a.obs.index]
    sc.pp.normalize_total(a, inplace=True)
    sc.pp.log1p(a)
    sc.pp.highly_variable_genes(a, flavor="seurat", n_top_genes=200)
    
    f = a[:, markers]
    
    df = pd.DataFrame(f.X.toarray(), columns=f.var.index, index=f.obs.index)
    
    return df

In [53]:
fmap = pd.read_csv('/home/estorrs/spatial-analysis/data/sample_map.txt', sep='\t', index_col=0)
fmap

Unnamed: 0_level_0,spaceranger_output,highres_image,disease
sample_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
HT206B1_H8_U2,/data/spatial_transcriptomics/spaceranger_outp...,/data/spatial_transcriptomics/highres_images/b...,brca
HT206B1_H8_U3,/data/spatial_transcriptomics/spaceranger_outp...,/data/spatial_transcriptomics/highres_images/b...,brca
HT206B1_H8_U4,/data/spatial_transcriptomics/spaceranger_outp...,/data/spatial_transcriptomics/highres_images/b...,brca
HT206B1_H8_U5,/data/spatial_transcriptomics/spaceranger_outp...,/data/spatial_transcriptomics/highres_images/b...,brca
HT206B1_H8_Bn,/data/spatial_transcriptomics/spaceranger_outp...,,brca
...,...,...,...
NMK_20201012,/data/spatial_transcriptomics/spaceranger_outp...,/data/spatial_transcriptomics/highres_images/m...,mouse_kidney
AKICL_14w,/data/spatial_transcriptomics/spaceranger_outp...,/data/spatial_transcriptomics/highres_images/m...,mouse_kidney
AKI_M_14w,/data/spatial_transcriptomics/spaceranger_outp...,/data/spatial_transcriptomics/highres_images/m...,mouse_kidney
SP1896H1_U1,/data/spatial_transcriptomics/spaceranger_outp...,/data/spatial_transcriptomics/highres_images/i...,normal_stomach


In [55]:
targets = None
for s, fp, d in zip(fmap.index, fmap['spaceranger_output'], fmap['disease']):
    if 'brca' in d:
        df = process_adata(s, fp)
        if targets is None:
            targets = df
        else:
            targets = pd.concat((targets, df), axis=0)
targets

Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Vari

Unnamed: 0,ESR1,PGR,ERBB2,MKI67,TOP2A,CD3G,CD4,CD8A,KIT,EPCAM,...,BGN,FAP,SPARC,ITGAX,LYZ,CD68,CD14,SDC1,PECAM1,IL7R
HT206B1_H8_U2_AAACAACGAATAGTTC-1,0.528794,0.0,0.000000,0.000000,0.528794,0.000000,0.528794,0.00000,0.000000,1.128382,...,2.308517,0.000000,2.703887,0.528794,0.528794,1.128382,0.872869,0.872869,2.075539,0.000000
HT206B1_H8_U2_AAACAAGTATCTCCCA-1,0.000000,0.0,0.699764,0.000000,0.699764,0.000000,1.107425,0.00000,0.699764,0.699764,...,2.577131,0.000000,1.620004,0.000000,0.699764,1.107425,1.396203,1.620004,0.000000,0.000000
HT206B1_H8_U2_AAACAATCTACTAGCA-1,0.000000,0.0,0.000000,0.531650,0.531650,0.000000,1.133083,0.53165,0.000000,0.531650,...,2.081613,0.000000,2.081613,0.000000,0.876917,1.336830,1.506007,0.531650,1.133083,0.531650
HT206B1_H8_U2_AAACAGAGCGACTCCT-1,0.000000,0.0,0.422854,0.000000,0.422854,0.000000,0.719118,0.00000,0.000000,0.422854,...,2.397888,0.422854,1.834677,0.000000,1.544191,1.425003,0.422854,0.422854,1.289662,0.719118
HT206B1_H8_U2_AAACAGGGTCTATATT-1,0.000000,0.0,1.030038,1.030038,0.000000,0.000000,0.000000,0.00000,0.000000,1.526566,...,1.030038,0.000000,0.000000,1.030038,1.030038,1.030038,1.526566,1.526566,0.000000,1.030038
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
HT268B1Th1_H3U1_TTGTTGTGTGTCAAGA-1,0.000000,0.0,0.286372,0.508728,0.977550,0.000000,0.000000,0.00000,0.000000,1.295467,...,0.286372,0.000000,1.295467,0.000000,0.000000,0.508728,0.286372,0.000000,0.508728,0.000000
HT268B1Th1_H3U1_TTGTTTCACATCCAGG-1,0.000000,0.0,0.594420,0.000000,0.000000,0.000000,0.000000,0.00000,0.000000,0.594420,...,1.899695,0.340741,2.210447,0.000000,0.340741,0.000000,0.340741,0.340741,0.796592,0.000000
HT268B1Th1_H3U1_TTGTTTCATTAGTCTA-1,0.000000,0.0,0.000000,0.313226,0.904754,0.000000,0.000000,0.00000,0.000000,1.165332,...,1.165332,0.000000,1.371853,0.000000,0.904754,0.000000,0.000000,0.000000,0.000000,0.000000
HT268B1Th1_H3U1_TTGTTTCCATACAACT-1,0.000000,0.0,0.573092,0.573092,0.000000,0.573092,0.935106,0.00000,0.000000,0.935106,...,2.167615,0.573092,3.086077,0.000000,2.593752,0.935106,1.409759,0.573092,0.573092,0.573092


In [79]:
val_regexs = [r'.*' + s for s in ['HT206B1_H8_U2']]
train_dataloader, val_dataloader = image_regression_dataloaders(img_dir, targets, val_regexs=val_regexs)

In [57]:
train_dataloader.dataset.labels

array(['ESR1', 'PGR', 'ERBB2', 'MKI67', 'TOP2A', 'CD3G', 'CD4', 'CD8A',
       'KIT', 'EPCAM', 'CDH1', 'BGN', 'FAP', 'SPARC', 'ITGAX', 'LYZ',
       'CD68', 'CD14', 'SDC1', 'PECAM1', 'IL7R'], dtype=object)

In [80]:
train_dataloader.dataset.samples, len(train_dataloader.dataset.samples)

(array(['HT206B1_H8_U3_AAACAACGAATAGTTC-1',
        'HT206B1_H8_U3_AAACAAGTATCTCCCA-1',
        'HT206B1_H8_U3_AAACAATCTACTAGCA-1', ...,
        'HT268B1Th1_H3U1_TTGTTTCATTAGTCTA-1',
        'HT268B1Th1_H3U1_TTGTTTCCATACAACT-1',
        'HT268B1Th1_H3U1_TTGTTTGTGTAAATTC-1'], dtype='<U34'),
 68480)

In [81]:
val_dataloader.dataset.samples, len(val_dataloader.dataset.samples)

(array(['HT206B1_H8_U2_AAACAACGAATAGTTC-1',
        'HT206B1_H8_U2_AAACAAGTATCTCCCA-1',
        'HT206B1_H8_U2_AAACAATCTACTAGCA-1', ...,
        'HT206B1_H8_U2_TTGTTTCCATACAACT-1',
        'HT206B1_H8_U2_TTGTTTGTATTACACG-1',
        'HT206B1_H8_U2_TTGTTTGTGTAAATTC-1'], dtype='<U32'),
 3859)

In [82]:
model = load_pretrained_model(weights)

Take key teacher in provided checkpoint dict
Pretrained weights found at /home/estorrs/dino/outputs/test_run_5_brca_good_only/checkpoint0480.pth and loaded with msg: _IncompatibleKeys(missing_keys=[], unexpected_keys=['head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'])


In [84]:
regressor = STRegressor(model, len(train_dataloader.dataset.labels))

In [85]:
regressor = regressor.cuda()

In [86]:
max_lr = 1e-4
learner = STLearner(regressor, train_dataloader, val_dataloader,
                   max_lr=max_lr)

In [87]:
epochs = 10
learner.fit(epochs)

epoch: 0, train loss: 1.2852318286895752, val loss: 0.7392365336418152
epoch: 1, train loss: 0.8184218406677246, val loss: 0.5383877754211426
epoch: 2, train loss: 0.6173110008239746, val loss: 0.44746753573417664
epoch: 3, train loss: 0.5175519585609436, val loss: 0.418517142534256
epoch: 4, train loss: 0.4677029550075531, val loss: 0.41857048869132996
epoch: 5, train loss: 0.4406231641769409, val loss: 0.421322226524353
epoch: 6, train loss: 0.42264577746391296, val loss: 0.4187578856945038
epoch: 7, train loss: 0.40919333696365356, val loss: 0.4185178875923157
epoch: 8, train loss: 0.39882588386535645, val loss: 0.41533875465393066
epoch: 9, train loss: 0.39062613248825073, val loss: 0.41424310207366943


In [88]:
learner.unfreeze_vit()

In [89]:
epochs = 10
learner.fit(epochs)

epoch: 0, train loss: 0.3225911259651184, val loss: 0.3581167161464691
epoch: 1, train loss: 0.2807820737361908, val loss: 0.3240653872489929
epoch: 2, train loss: 0.26617154479026794, val loss: 0.33561214804649353
epoch: 3, train loss: 0.2576692998409271, val loss: 0.33494752645492554
epoch: 4, train loss: 0.25164785981178284, val loss: 0.3257253170013428
epoch: 5, train loss: 0.24699285626411438, val loss: 0.3269999623298645
epoch: 6, train loss: 0.2429467737674713, val loss: 0.3244311511516571
epoch: 7, train loss: 0.23941180109977722, val loss: 0.31136462092399597
epoch: 8, train loss: 0.236506849527359, val loss: 0.32133060693740845
epoch: 9, train loss: 0.23380810022354126, val loss: 0.3083851635456085


In [90]:
train_dataloader.batch_size

64

In [94]:
regressor.vit.patch_embed

PatchEmbed(
  (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
)

In [102]:
next(iter(regressor.vit.patch_embed.num_patches))

Parameter containing:
tensor([[[[-1.2361e-03,  9.9253e-03, -1.3006e-02,  ..., -1.9687e-02,
           -1.3001e-02,  1.2334e-03],
          [-1.9194e-03,  2.0495e-03, -1.9928e-02,  ..., -1.5762e-02,
           -1.7594e-02,  1.7031e-03],
          [-1.9589e-02, -1.5938e-02, -2.0802e-02,  ..., -1.9639e-02,
           -7.9108e-03, -9.2732e-03],
          ...,
          [-1.0810e-02, -5.1164e-03, -1.3453e-02,  ...,  4.8155e-03,
           -5.8476e-03, -9.5824e-03],
          [-2.5694e-02, -4.0821e-04, -2.2047e-02,  ..., -2.1958e-02,
           -1.8879e-02, -8.0126e-03],
          [-9.4168e-03,  2.2903e-03, -1.6485e-02,  ..., -1.1925e-02,
           -1.5558e-02, -9.2207e-03]],

         [[-1.3862e-03,  6.1599e-03, -1.1952e-02,  ...,  9.3583e-03,
            1.1061e-02,  3.3222e-02],
          [-1.9516e-02, -1.5952e-02, -3.1832e-02,  ...,  8.8154e-03,
           -5.7496e-04,  2.3535e-02],
          [-1.5078e-02, -2.2050e-02, -2.9011e-02,  ...,  1.2116e-03,
           -7.4587e-03,  1.1714e-02]

In [111]:
regressor.vit.patch_embed.proj.kernel_size[0]

16

In [108]:
regressor.vit.

384

In [112]:
224 / 16

14.0

In [113]:
224 / 14

16.0

In [114]:
14 * 14

196

In [None]:
num_patches = (img_size // patch_size) * (img_size // patch_size)

In [115]:
224 // 16

14

In [116]:
import pprint