In [1]:
from tqdm.auto import tqdm
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

from dataset import CANDataset
from networks.simple_cnn import SupConCNN
from SupContrast.networks.resnet_big import SupConResNet, SupCEResNet

from SupContrast.util import AverageMeter
from SupContrast.util import accuracy

torch.manual_seed(0)

<torch._C.Generator at 0x7fea4184dab0>

In [2]:
# Class counts; presumably for the source (pre-training) and target (transfer) label sets.
# NOTE(review): NUM_SOURCE_CLASSES is not referenced anywhere else in the visible notebook.
NUM_SOURCE_CLASSES = 5
NUM_TARGET_CLASSES = 4
# Passed to CANDataset as `window_size`; the data directory name also carries "w29".
window_size = 29

In [3]:
# Both splits are built from the same directory; `is_train=False` presumably
# selects the validation portion (verify against CANDataset).
data_root = '../Data/Survival/TFrecord_KIA_w29_s29/1/'
train_dataset = CANDataset(data_root, window_size=window_size)
val_dataset = CANDataset(data_root, window_size=window_size, is_train=False)

# Settings shared by both loaders; only shuffling and worker count differ.
shared_loader_kwargs = dict(batch_size=256, pin_memory=True)

train_loader = torch.utils.data.DataLoader(
    train_dataset, shuffle=True, num_workers=10, **shared_loader_kwargs)

val_loader = torch.utils.data.DataLoader(
    val_dataset, num_workers=8, **shared_loader_kwargs)

In [4]:
def change_new_state_dict_parallel(state_dict):
    """Return a copy of ``state_dict`` whose keys have ``module`` wrapper segments removed.

    ``nn.DataParallel`` nests the wrapped network under an attribute called
    ``module``, so checkpoints saved from a wrapped model contain keys such as
    ``module.fc.weight`` or ``encoder.module.conv1.weight``. Those keys do not
    match an unwrapped model.

    Only path components that are exactly ``module`` are dropped. A plain
    substring replacement would also mangle components that merely end in
    "module" (for example ``submodule.weight``). The final component is the
    parameter/buffer name and is never dropped.

    Args:
        state_dict: Mapping from dotted parameter names to values.

    Returns:
        A new ``dict`` with cleaned keys and the same value objects. The input
        mapping is not modified.
    """
    new_state_dict = {}
    for key, value in state_dict.items():
        *parents, leaf = key.split(".")
        kept = [part for part in parents if part != "module"]
        new_state_dict[".".join(kept + [leaf])] = value
    return new_state_dict

In [5]:
# Alternative checkpoint (simple CNN backbone):
# save_path = './save/save_1/SupCon_cnn_lr0.1_0.5_bs1024_200epoch_temp0.07_033122_154210_cosine_warm/models/'
save_path = './save/SupCon_resnet18_lr0.005_0.001_bs4096_300epoch_temp0.07_041522_144805_cosine_warm/models/'
ckpt_epoch = 200
model_path = f'{save_path}/ckpt_epoch_{ckpt_epoch}.pth'
# Load onto the CPU first so the checkpoint can be read on a machine without
# CUDA; the model is moved to the GPU later only when one is available.
ckpt = torch.load(model_path, map_location='cpu')
# Strip the DataParallel `module` segments so the keys match the bare model.
state_dict = change_new_state_dict_parallel(ckpt['model'])
# model = SupConCNN(feat_dim=128)
model = SupConResNet(name='resnet18')
model.load_state_dict(state_dict=state_dict)

<All keys matched successfully>

## Train the classifier with fixed pretrained model first

In [6]:
class LinearClassifier(nn.Module):
    """Single fully-connected layer that maps feature vectors to per-class scores."""

    def __init__(self, n_classes, feat_dim):
        """Build the layer.

        Args:
            n_classes: Number of output units; also stored as ``self.n_classes``.
            feat_dim: Size of the last dimension of the input features.
        """
        super(LinearClassifier, self).__init__()
        self.n_classes = n_classes
        self.fc = nn.Linear(in_features=feat_dim, out_features=n_classes)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return its output."""
        return self.fc(x)

In [7]:
# Linear head for the target task. The class count comes from the notebook-level
# constant instead of a repeated literal, so it cannot drift from the CE baseline.
# feat_dim=512 assumes the encoder emits 512-dimensional features — TODO confirm.
classifier = LinearClassifier(n_classes=NUM_TARGET_CLASSES, feat_dim=512)
criterion = torch.nn.CrossEntropyLoss()

# Move everything to the GPU only when one is present.
if torch.cuda.is_available():
    model = model.cuda()
    criterion = criterion.cuda()
    classifier = classifier.cuda()

In [8]:
# Only the classifier's parameters are handed to the optimizer, so the
# pretrained encoder's weights are never updated during this stage.
optimizer = optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9, weight_decay=0)

In [9]:
def train(train_loader, model, classifier, criterion, optimizer, epoch):
    """Train ``classifier`` for one epoch on features from a frozen ``model.encoder``.

    The encoder is put in eval mode and run under ``torch.no_grad()``, so only
    the classifier receives gradients.

    Args:
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        model: Network exposing an ``encoder`` sub-module used as feature extractor.
        classifier: Head applied to the encoder output.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        epoch: Current epoch number; used only to label the progress bar.

    Returns:
        Tuple ``(losses.avg, accs.avg)`` read from the loss and top-1 accuracy
        meters after the last batch.
    """
    model.eval()
    classifier.train()

    losses = AverageMeter()
    accs = AverageMeter()

    # Follow the device the model already lives on instead of forcing CUDA,
    # so the function also works when no GPU is available.
    device = next(model.parameters()).device

    # Wrap the loader itself (not enumerate) so tqdm can see its length and
    # display real progress instead of "0it".
    for inputs, labels in tqdm(train_loader, desc=f'train epoch {epoch}'):
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        bsz = labels.shape[0]

        # No graph is built for the encoder, so its output is already detached.
        with torch.no_grad():
            feats = model.encoder(inputs)

        outputs = classifier(feats)
        loss = criterion(outputs, labels)
        acc = accuracy(outputs, labels, topk=(1, ))

        losses.update(loss.item(), bsz)
        accs.update(acc[0].item(), bsz)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return losses.avg, accs.avg

In [10]:
# Linear-probe stage: train only the classifier head for EPOCHS epochs.
# NOTE: EPOCHS is rebound by later cells in this notebook.
EPOCHS = 30
for epoch in range(1, EPOCHS + 1):
    loss, acc = train(train_loader, model, classifier, criterion, optimizer, epoch)
    print(f'Epoch {epoch}: loss={loss}, acc={acc}')

0it [00:00, ?it/s]

Epoch 1: loss=0.4509581571292231, acc=83.9460532621346


0it [00:00, ?it/s]

Epoch 2: loss=0.3336355446516435, acc=87.14315054164804


0it [00:00, ?it/s]

Epoch 3: loss=0.30316376011172647, acc=87.49914424329795


0it [00:00, ?it/s]

Epoch 4: loss=0.28556842270667543, acc=88.17005545025351


0it [00:00, ?it/s]

Epoch 5: loss=0.2930581980766992, acc=88.24536180760636


0it [00:00, ?it/s]

Epoch 6: loss=0.2685751601127357, acc=89.14903813240227


0it [00:00, ?it/s]

Epoch 7: loss=0.253683182373342, acc=89.73095091394536


0it [00:00, ?it/s]

Epoch 8: loss=0.2376763757287399, acc=90.8673923409123


0it [00:00, ?it/s]

Epoch 9: loss=0.23574318062837196, acc=90.90846854247964


0it [00:00, ?it/s]

Epoch 10: loss=0.2251367190477988, acc=91.40822893133429


0it [00:00, ?it/s]

Epoch 11: loss=0.22218766960816486, acc=91.51091941697173


0it [00:00, ?it/s]

Epoch 12: loss=0.21468342546259606, acc=92.01067980582638


0it [00:00, ?it/s]

Epoch 13: loss=0.21012566658275658, acc=92.19552269198692


0it [00:00, ?it/s]

Epoch 14: loss=0.2047877238123962, acc=92.27767508467537


0it [00:00, ?it/s]

Epoch 15: loss=0.2027207408198555, acc=92.56520846169644


0it [00:00, ?it/s]

Epoch 16: loss=0.19973680261118812, acc=92.68843704811755


0it [00:00, ?it/s]

Epoch 17: loss=0.2081879438435696, acc=92.32559731635517


0it [00:00, ?it/s]

Epoch 18: loss=0.18816168352062798, acc=93.31142602861642


0it [00:00, ?it/s]

Epoch 19: loss=0.1880901898782487, acc=93.20188950503183


0it [00:00, ?it/s]

Epoch 20: loss=0.1864024410585999, acc=93.47573080877018


0it [00:00, ?it/s]

Epoch 21: loss=0.18881880747924687, acc=93.1197371123434


0it [00:00, ?it/s]

Epoch 22: loss=0.18218063325671238, acc=93.46888477604615


0it [00:00, ?it/s]

Epoch 23: loss=0.18737495637434973, acc=93.24296570137605


0it [00:00, ?it/s]

Epoch 24: loss=0.18489987529294483, acc=93.31142602861642


0it [00:00, ?it/s]

Epoch 25: loss=0.1769632895573328, acc=93.72903401955953


0it [00:00, ?it/s]

Epoch 26: loss=0.1733937021542552, acc=93.85910864653933


0it [00:00, ?it/s]

Epoch 27: loss=0.17035570158098745, acc=93.92756896855659


0it [00:00, ?it/s]

Epoch 28: loss=0.17276043796233062, acc=93.93441500650373


0it [00:00, ?it/s]

Epoch 29: loss=0.17085412722467325, acc=93.98918326829602


0it [00:00, ?it/s]

Epoch 30: loss=0.1672974519873821, acc=94.24933250919786


In [11]:
# Collect top-1 predictions and ground-truth labels over the validation set.
# Batches are appended to lists and concatenated once at the end; re-allocating
# the full array on every batch is quadratic. Seeding each list with an empty
# int array keeps the final concatenate valid when the loader yields nothing.
pred_chunks = [np.empty(shape=(0), dtype=int)]
label_chunks = [np.empty(shape=(0), dtype=int)]

model.eval()
classifier.eval()
# Use whatever device the model already lives on instead of forcing CUDA.
device = next(model.parameters()).device
with torch.no_grad():
    for images, labels in tqdm(val_loader):
        images = images.to(device, non_blocking=True)
        outputs = classifier(model.encoder(images))
        # argmax over the class dimension is the top-1 prediction.
        pred_chunks.append(outputs.argmax(dim=1).cpu().numpy())
        label_chunks.append(labels.cpu().numpy())

total_pred = np.concatenate(pred_chunks, axis=0)
total_label = np.concatenate(label_chunks, axis=0)

  0%|          | 0/25 [00:00<?, ?it/s]

In [13]:
from utils import cal_metric

# Print each metric returned by cal_metric as a list of values formatted to
# four decimal places.
cm, results = cal_metric(total_label, total_pred)
for metric_name, per_class in results.items():
    print(metric_name, [f"{score:0.4f}" for score in per_class])

fnr ['8.6269', '0.2574', '3.8431', '11.5142']
rec ['0.9137', '0.9974', '0.9616', '0.8849']
pre ['0.9725', '1.0000', '0.9017', '0.7276']
f1 ['0.9422', '0.9987', '0.9306', '0.7986']


# Fine-tuning the whole model

In [8]:
import copy

In [9]:
class TransferModel(nn.Module):
    """Feature extractor followed by a classifier, trained end to end.

    Both sub-modules are deep-copied at construction time, so training this
    model leaves the modules passed in untouched.
    """

    def __init__(self, feat_extractor, classifier):
        """Store independent copies of the two modules.

        Args:
            feat_extractor: Module mapping inputs to features; kept as ``self.encoder``.
            classifier: Module mapping features to outputs; kept as ``self.classifier``.
        """
        super(TransferModel, self).__init__()
        encoder_copy, head_copy = (
            copy.deepcopy(module) for module in (feat_extractor, classifier)
        )
        self.encoder = encoder_copy
        self.classifier = head_copy

    def forward(self, x):
        """Run ``x`` through the encoder and then the classifier."""
        return self.classifier(self.encoder(x))

In [10]:
# Build the end-to-end model from copies of the pretrained encoder and the
# current classifier head.
# NOTE(review): `classifier` is copied in whatever state it has when this cell
# runs. The saved execution counts restart at 8 for this section, so the
# linear-probe training cells may not have been executed in the same session —
# confirm with Restart & Run All.
fine_tuned_model = TransferModel(model.encoder, classifier)
# Guard the GPU move the same way the classifier setup cell does.
if torch.cuda.is_available():
    fine_tuned_model = fine_tuned_model.cuda()

In [11]:
# All parameters of the fine-tuned model (encoder and classifier) are optimised here,
# with a learning rate of 1e-4 versus 1e-2 for the classifier-only stage.
optimizer_whole = optim.SGD(fine_tuned_model.parameters(), lr=0.0001, momentum=0.9, weight_decay=0)

In [12]:
def train_whole(train_loader, model, criterion, optimizer, epoch):
    """Train every parameter of ``model`` for one epoch.

    Args:
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        model: Network called as ``model(inputs)``; put in training mode.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        epoch: Current epoch number; used only to label the progress bar.

    Returns:
        Tuple ``(losses.avg, accs.avg)`` read from the loss and top-1 accuracy
        meters after the last batch.
    """
    model.train()

    losses = AverageMeter()
    accs = AverageMeter()

    # Follow the device the model already lives on instead of forcing CUDA,
    # so the function also works when no GPU is available.
    device = next(model.parameters()).device

    # Wrap the loader itself (not enumerate) so tqdm can see its length and
    # display real progress instead of "0it".
    for inputs, labels in tqdm(train_loader, desc=f'train epoch {epoch}'):
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        bsz = labels.shape[0]

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        acc = accuracy(outputs, labels, topk=(1, ))

        losses.update(loss.item(), bsz)
        accs.update(acc[0].item(), bsz)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return losses.avg, accs.avg

In [13]:
# Fine-tuning stage: update encoder and classifier together.
# NOTE: this rebinds EPOCHS, which the linear-probe cell set to 30.
EPOCHS = 10
for epoch in range(1, EPOCHS + 1):
    loss, acc = train_whole(train_loader, fine_tuned_model, criterion, optimizer_whole, epoch)
    print(f'Epoch {epoch}: loss={loss}, acc={acc}')

0it [00:00, ?it/s]

Epoch 1: loss=0.9785805408115822, acc=67.04319846648868


0it [00:00, ?it/s]

Epoch 2: loss=0.5737567878507235, acc=87.13630450892401


0it [00:00, ?it/s]

Epoch 3: loss=0.46272274787530293, acc=87.77298554964784


0it [00:00, ?it/s]

Epoch 4: loss=0.4091925606873194, acc=88.30013007462176


0it [00:00, ?it/s]

Epoch 5: loss=0.3728380123023389, acc=88.56712535085917


0it [00:00, ?it/s]

Epoch 6: loss=0.34206514444477365, acc=88.73143013101294


0it [00:00, ?it/s]

Epoch 7: loss=0.30920137063575703, acc=88.78619839802835


0it [00:00, ?it/s]

Epoch 8: loss=0.27848169872703155, acc=89.08742383788595


0it [00:00, ?it/s]

Epoch 9: loss=0.2492618537327833, acc=90.40186211828939


0it [00:00, ?it/s]

Epoch 10: loss=0.22224487276050175, acc=92.47621003106086


In [14]:
# Collect top-1 predictions and ground-truth labels for the fine-tuned model.
# Batches are appended to lists and concatenated once at the end; re-allocating
# the full array on every batch is quadratic. Seeding each list with an empty
# int array keeps the final concatenate valid when the loader yields nothing.
pred_chunks = [np.empty(shape=(0), dtype=int)]
label_chunks = [np.empty(shape=(0), dtype=int)]

fine_tuned_model.eval()
# Use whatever device the model already lives on instead of forcing CUDA.
device = next(fine_tuned_model.parameters()).device
with torch.no_grad():
    for images, labels in tqdm(val_loader):
        images = images.to(device, non_blocking=True)
        outputs = fine_tuned_model(images)
        # argmax over the class dimension is the top-1 prediction.
        pred_chunks.append(outputs.argmax(dim=1).cpu().numpy())
        label_chunks.append(labels.cpu().numpy())

total_pred = np.concatenate(pred_chunks, axis=0)
total_label = np.concatenate(label_chunks, axis=0)

  0%|          | 0/25 [00:00<?, ?it/s]

In [15]:
from utils import cal_metric

# Print each metric returned by cal_metric as a list of values formatted to
# four decimal places.
cm, results = cal_metric(total_label, total_pred)
for metric_name, per_class in results.items():
    print(metric_name, [f"{score:0.4f}" for score in per_class])

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fea40dd90e0>
Traceback (most recent call last):
  File "/home/dhkim1/miniconda3/envs/torch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/dhkim1/miniconda3/envs/torch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1301, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/dhkim1/miniconda3/envs/torch/lib/python3.7/multiprocessing/process.py", line 140, in join
    res = self._popen.wait(timeout)
  File "/home/dhkim1/miniconda3/envs/torch/lib/python3.7/multiprocessing/popen_fork.py", line 45, in wait
    if not wait([self.sentinel], timeout):
  File "/home/dhkim1/miniconda3/envs/torch/lib/python3.7/multiprocessing/connection.py", line 921, in wait
    ready = selector.select(timeout)
  File "/home/dhkim1/miniconda3/envs/torch/lib/python3.7/selectors.py", line 415, in select
    fd_

fnr ['0.0832', '0.2574', '3.9231', '53.7855']
rec ['0.9992', '0.9974', '0.9608', '0.4621']
pre ['0.9332', '1.0000', '0.9063', '0.9575']
f1 ['0.9650', '0.9987', '0.9328', '0.6234']


# Cross entropy with ResNet18 backbone

In [21]:
# Cross-entropy baseline: a SupCEResNet built here with no checkpoint loaded in this cell.
# NOTE: `optimizer` and `criterion` are rebound, replacing the objects created
# for the linear-probe stage earlier in the notebook.
ce_model = SupCEResNet(name='resnet18', num_classes=NUM_TARGET_CLASSES).cuda()
criterion = torch.nn.CrossEntropyLoss().cuda()
sgd_settings = dict(lr=0.0001, momentum=0.9, weight_decay=0)
optimizer = optim.SGD(ce_model.parameters(), **sgd_settings)

In [22]:
# Train the cross-entropy baseline end to end with the same loop used for fine-tuning.
# NOTE(review): the stored output of this cell ends in a KeyboardInterrupt after
# epoch 6, so the recorded run did not complete all EPOCHS.
EPOCHS = 10
for epoch in range(1, EPOCHS + 1):
    loss, acc = train_whole(train_loader, ce_model, criterion, optimizer, epoch)
    print(f'Epoch {epoch}: loss={loss}, acc={acc}')

0it [00:00, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7efd2a6bc0e0>
Traceback (most recent call last):
  File "/home/dhkim1/miniconda3/envs/torch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/dhkim1/miniconda3/envs/torch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/dhkim1/miniconda3/envs/torch/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7efd2a6bc0e0>
Traceback (most recent call last):
  File "/home/dhkim1/miniconda3/envs/torch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/dhkim1/miniconda3/envs/torch/lib/python

Epoch 1: loss=1.1629912762407846, acc=52.63229957978045


0it [00:00, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7efd2a6bc0e0>
Traceback (most recent call last):
  File "/home/dhkim1/miniconda3/envs/torch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/dhkim1/miniconda3/envs/torch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/dhkim1/miniconda3/envs/torch/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7efd2a6bc0e0>
Traceback (most recent call last):
  File "/home/dhkim1/miniconda3/envs/torch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7efd2a

Epoch 2: loss=1.0763845624983772, acc=57.554597108362636


0it [00:00, ?it/s]

Epoch 3: loss=1.038529914165005, acc=57.55459711097419


0it [00:00, ?it/s]

Epoch 4: loss=0.9970222044025365, acc=57.575135209146296


0it [00:00, ?it/s]

Epoch 5: loss=0.9518921679275892, acc=58.875881421490114


0it [00:00, ?it/s]

Epoch 6: loss=0.9051088529861527, acc=65.22215375667187


KeyboardInterrupt: 

In [18]:
from transfer import evaluate

In [19]:
# Metrics of the cross-entropy model on the training loader.
# NOTE(review): this cell's execution count (19) is lower than that of the cell
# that builds ce_model (21), so the stored output may come from a different
# model state — confirm with Restart & Run All.
evaluate(ce_model, train_loader)

  0%|          | 0/58 [00:00<?, ?it/s]

fnr ['0.0000', '0.0000', '0.0000', '0.0000']
rec ['1.0000', '1.0000', '1.0000', '1.0000']
pre ['1.0000', '1.0000', '1.0000', '1.0000']
f1 ['1.0000', '1.0000', '1.0000', '1.0000']


{'fnr': array([0., 0., 0., 0.]),
 'rec': array([1., 1., 1., 1.]),
 'pre': array([1., 1., 1., 1.]),
 'f1': array([1., 1., 1., 1.])}

In [20]:
# Metrics of the cross-entropy model on the validation loader.
# NOTE(review): this cell's execution count (20) is lower than that of the cell
# that builds ce_model (21), so the stored output may come from a different
# model state — confirm with Restart & Run All.
evaluate(ce_model, val_loader)

  0%|          | 0/25 [00:00<?, ?it/s]

fnr ['0.0000', '0.0000', '0.2402', '0.1577']
rec ['1.0000', '1.0000', '0.9976', '0.9984']
pre ['0.9989', '1.0000', '1.0000', '1.0000']
f1 ['0.9994', '1.0000', '0.9988', '0.9992']


{'fnr': array([0.        , 0.        , 0.24019215, 0.15772871]),
 'rec': array([1.        , 1.        , 0.99759808, 0.99842271]),
 'pre': array([0.99889166, 1.        , 1.        , 1.        ]),
 'f1': array([0.99944552, 1.        , 0.9987976 , 0.99921073])}

In [None]:
# NOTE(review): this unexecuted cell held a pasted fragment of the metrics
# output shown above. It is not valid Python and would break Run All, so it is
# kept only as a comment.
# 0.2402', '0.1577']
# rec ['1.0000', '1.0000', '0.9976', '0.9984']
# pre ['0.9989', '1.0000', '1.0000', '1.0000']
# f1 ['0.9994', '1.0000', '0.9988', '0.9992']