In [1]:
import os
import sys
# Make the project's `src/` directory (which holds the `models` and `dataset`
# packages imported below) importable. The relative path is resolved against the
# current working directory. It is inserted at the front of sys.path so these local
# packages take precedence over any installed package with the same name.
module_path = os.path.abspath('../src')
if module_path not in sys.path:
    sys.path.insert(0, module_path)
    
from pathlib import Path

import numpy as np
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import WandbLogger

from dataset.MultimodalityDataModule import MultimodalityDataModule
from models.ResNetClassifier import ResNet
from models.ShallowResNetClassifier import ShallowResNet

In [2]:
# Text preprocessing limits passed to the data module.
MAX_NUMBER_WORDS = 20000       # number of words to consider from the embeddings vocabulary
MAX_WORDS_PER_SENTENCE = 300   # maximum sentence length, in words
NUM_CLASSES = 4                # 4 microscopy classes

# Data and output locations.
BASE_PATH = Path('/workspace/data')
DATA_PATH = BASE_PATH / 'multimodality_classification.csv'
OUTPUT_DIR = Path('./outputs')
BASE_IMG_DIR = BASE_PATH       # image paths in the CSV file are relative to this directory

# Data-loader worker processes: keep the original upper bound of 72, but never
# exceed the local CPU count so smaller hosts are not oversubscribed.
# os.cpu_count() may return None, hence the fallback to 1.
NUM_WORKERS = min(72, os.cpu_count() or 1)

In [3]:
# Build the multimodal (text + image) data module, then run its
# prepare_data and setup hooks so the splits and class weights are available.
dm = MultimodalityDataModule(
    32,                       # NOTE(review): presumably the batch size; confirm against the constructor
    str(DATA_PATH),           # CSV listing the samples
    MAX_NUMBER_WORDS,
    MAX_WORDS_PER_SENTENCE,
    str(BASE_IMG_DIR),        # image paths in the CSV are relative to this directory
    num_workers=NUM_WORKERS,
)
dm.prepare_data()
dm.setup()

# Class weights exposed by the data module; passed to the classifiers below.
class_weights = dm.class_weights

In [4]:
output_dir = './temp'  # passed to the Trainer as default_root_dir in the training cell

# Model configuration. `fine_tuned_from` selects which part of the network is
# trained. NOTE(review): confirm the accepted values in models.ResNetClassifier.
name = 'resnet50'
fine_tuned_from = "whole"
lr = 5e-6
resnet = ResNet(name, NUM_CLASSES, fine_tuned_from=fine_tuned_from, lr=lr,
                class_weights=class_weights)

# Number of trainable parameter *tensors* (not individual weights) in the wrapped model.
sum(p.requires_grad for p in resnet.model.parameters())

161

In [10]:
# Residual blocks per stage for the shallow variant (the printed model shows two
# BasicBlocks in layer1).
# NOTE(review): [2, 2, 2, 2] is the standard ResNet-18 layout, so this is not
# actually shallower than 'resnet18'; confirm the intended configuration.
layers = [2, 2, 2, 2]
shallow_resnet = ShallowResNet('resnet18', NUM_CLASSES, layers, class_weights=class_weights)

In [11]:
# Dump the module tree to check the configured layer layout.
print(str(shallow_resnet))

ShallowResNet(
  (model): ShallowTorchResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine

In [5]:
PROJECT = 'biomedical-multimodal'
# Tag the run with the backbone actually being trained. The tag was previously
# hard-coded as 'resnet152' while the model is `name` ('resnet50').
wandb_logger = WandbLogger(project=PROJECT, tags=['nb', 'image', name])
wandb_logger.experiment.save()
print(wandb_logger.experiment.name)

# Per-run output directory named after the W&B run. exist_ok=False fails loudly
# instead of mixing artifacts from two runs with the same name.
output_run_path = OUTPUT_DIR / wandb_logger.experiment.name
output_run_path.mkdir(parents=True, exist_ok=False)
# NOTE(review): output_run_path is created but not passed to the Trainer, which
# uses output_dir as default_root_dir. Confirm which location is intended.

# Stop training when val_loss has not decreased for `patience` validation checks.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.0,
    patience=5,
    verbose=True,
    mode='min',
)

# `early_stop_callback=` is the legacy Lightning argument this notebook ran with.
# Newer Lightning releases dropped it in favor of `callbacks=[early_stop_callback]`.
trainer = Trainer(gpus=1, logger=wandb_logger, max_epochs=200, default_root_dir=output_dir,
                  early_stop_callback=early_stop_callback)
trainer.fit(resnet, dm)

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
[34m[1mwandb[0m: Wandb version 0.10.4 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


chocolate-darkness-206



  | Name      | Type             | Params
-----------------------------------------------
0 | model     | ResNet           | 23 M  
1 | criterion | CrossEntropyLoss | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Saving latest checkpoint..
Epoch 00018: early stopping triggered.





1

In [6]:
# Write a manual checkpoint of the trained model. Create the parent directory
# first: Trainer.save_checkpoint is not guaranteed to create missing directories.
os.makedirs("./temp", exist_ok=True)
trainer.save_checkpoint("./temp/example.ckpt")

In [7]:
# Reload the checkpoint written above without passing constructor arguments.
# The displayed hparams show what was restored (expect fine_tuned_from="whole").
checkpoint_file = "./temp/example.ckpt"
resnet2 = ResNet.load_from_checkpoint(checkpoint_file)
resnet2.hparams

"class_weights":   [2.25120773 0.62050599 0.74085851 1.68231047]
"fine_tuned_from": whole
"lr":              5e-06
"name":            resnet50
"num_classes":     4
"pretrained":      True

In [8]:
# Trainable parameter tensors in the reloaded model; compare with the count for `resnet`.
sum(p.requires_grad for p in resnet2.model.parameters())

161

In [9]:
# Show the classifier head (fc) weight matrix and bias of the reloaded model.
for fc_param in resnet2.model.fc.parameters():
    print(fc_param.data)

tensor([[-0.0131, -0.0052, -0.0219,  ...,  0.0208, -0.0127,  0.0187],
        [ 0.0210, -0.0136, -0.0188,  ...,  0.0063,  0.0225,  0.0140],
        [ 0.0166,  0.0114, -0.0210,  ..., -0.0194, -0.0205,  0.0209],
        [-0.0118, -0.0068, -0.0058,  ...,  0.0083, -0.0021,  0.0087]])
tensor([-0.0130, -0.0109,  0.0219,  0.0201])


In [10]:
# Trainable parameter tensors in the in-memory trained model, for comparison with `resnet2`.
sum(p.requires_grad for p in resnet.model.parameters())

161

In [11]:
# Show the classifier head (fc) weight matrix and bias of the in-memory model,
# to compare against the reloaded checkpoint's values.
for fc_param in resnet.model.fc.parameters():
    print(fc_param.data)

tensor([[-0.0131, -0.0052, -0.0219,  ...,  0.0208, -0.0127,  0.0187],
        [ 0.0210, -0.0136, -0.0188,  ...,  0.0063,  0.0225,  0.0140],
        [ 0.0166,  0.0114, -0.0210,  ..., -0.0194, -0.0205,  0.0209],
        [-0.0118, -0.0068, -0.0058,  ...,  0.0083, -0.0021,  0.0087]])
tensor([-0.0130, -0.0109,  0.0219,  0.0201])
