In [2]:
from functools import partial
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from dataset import ChestXray14
from model import get_encoder
from losses import ContrastiveLoss, SupConLoss

import math
from sklearn.metrics import auc, roc_curve
import warnings

In [7]:


warnings.filterwarnings('ignore')

device = "cuda" if torch.cuda.is_available() else "cpu"

batch_size = 32
epoch_count = 100

encoder_choice = 'vit'
model = get_encoder(encoder_choice=encoder_choice)

model = model.to(device)

criterion = nn.CrossEntropyLoss() 

optimizer = optim.Adam(model.parameters(), lr=1e-3)

def train(data_loader, class_name):

    for epoch in range(1, epoch_count + 1):
        progress_bar = tqdm.tqdm(data_loader)
        correct = 0
        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)
            features = model(images)
            loss = criterion(features, labels)  
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            progress_bar.set_description(f"Loss: {loss.item()}")


        torch.save(model.state_dict(), 'weights/{}_{}_weights'.format(encoder_choice, class_name))
        print("Epoch: {} | Loss: {:.2f} | Accuracy: ".format(epoch, loss.item()))


class_name = 'Cardiomegaly'
train_dataset = ChestXray14(phase='train', class_name=class_name)
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

print('Training on {} images ({})'.format(len(train_dataset), class_name))
train(data_loader=train_data_loader, class_name=class_name)



Using cache found in /home/developer/.cache/torch/hub/facebookresearch_deit_main


Training on 3330 images (Cardiomegaly)


Loss: 2.05824875831604:   4%|▋                  | 4/105 [00:00<00:15,  6.42it/s]

tensor([[ 1.1213,  0.5559],
        [ 1.2161,  0.4791],
        [ 0.7712,  0.5057],
        [ 0.9049,  0.2789],
        [ 0.7394,  0.5842],
        [ 1.5370,  0.0813],
        [ 0.7979,  0.1009],
        [ 2.0940,  0.2045],
        [ 0.8360,  0.5392],
        [ 1.1018,  0.0466],
        [ 0.9446, -0.1795],
        [ 1.2894,  0.1532],
        [ 1.6401, -0.0229],
        [ 0.8889,  0.6681],
        [ 0.9597, -0.0743],
        [ 1.5979,  0.3502],
        [ 1.3344, -0.5556],
        [ 0.7029, -0.1286],
        [ 0.6751,  0.1476],
        [ 1.1645,  0.5500],
        [ 1.8527,  0.7671],
        [ 1.0231,  0.1176],
        [ 1.1577,  0.2009],
        [ 1.7747, -0.2238],
        [ 0.8264,  0.3147],
        [ 2.1178,  0.1304],
        [ 0.8476,  0.2040],
        [ 1.3224,  0.0465],
        [ 1.3367,  1.0960],
        [ 1.2095, -0.1130],
        [ 1.5141,  0.1632],
        [ 1.5841,  0.2294]], device='cuda:0', grad_fn=<AddmmBackward0>)
tensor([0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0

Loss: 0.8089125156402588:   6%|▉                | 6/105 [00:01<00:16,  5.95it/s]

tensor([[0.3070, 0.3090],
        [0.2850, 0.3601],
        [0.2545, 0.3674],
        [0.2878, 0.3409],
        [0.2938, 0.3234],
        [0.3224, 0.2804],
        [0.3163, 0.2960],
        [0.2764, 0.3487],
        [0.2897, 0.3495],
        [0.2541, 0.3746],
        [0.2756, 0.3478],
        [0.2765, 0.3636],
        [0.2539, 0.3671],
        [0.2655, 0.3536],
        [0.2954, 0.3332],
        [0.2727, 0.3609],
        [0.2938, 0.3374],
        [0.3261, 0.2837],
        [0.2664, 0.3570],
        [0.2795, 0.3359],
        [0.2882, 0.3503],
        [0.2961, 0.3319],
        [0.2603, 0.3846],
        [0.3267, 0.2804],
        [0.2784, 0.3390],
        [0.2738, 0.3617],
        [0.3172, 0.3008],
        [0.2750, 0.3765],
        [0.3035, 0.3200],
        [0.2658, 0.3640],
        [0.2870, 0.3334],
        [0.2673, 0.3541]], device='cuda:0', grad_fn=<AddmmBackward0>)
tensor([1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1,
        0, 0, 1, 0, 1, 1, 1, 1], device='cud

Loss: 1.0235835313796997:  11%|█▊              | 12/105 [00:01<00:10,  9.21it/s]

tensor([[ 0.4919, -0.1316],
        [ 0.5396, -0.1331],
        [ 0.5549, -0.2002],
        [ 0.5071, -0.1469],
        [ 0.5061, -0.1405],
        [ 0.5238, -0.1309],
        [ 0.4931, -0.1406],
        [ 0.5464, -0.1291],
        [ 0.5239, -0.1330],
        [ 0.5309, -0.1341],
        [ 0.5012, -0.1380],
        [ 0.5198, -0.1189],
        [ 0.5412, -0.1285],
        [ 0.4841, -0.1510],
        [ 0.5206, -0.1651],
        [ 0.5365, -0.1329],
        [ 0.5271, -0.1301],
        [ 0.5384, -0.1330],
        [ 0.5404, -0.1311],
        [ 0.5226, -0.1288],
        [ 0.5421, -0.1283],
        [ 0.5482, -0.1232],
        [ 0.5221, -0.1258],
        [ 0.5438, -0.1307],
        [ 0.5341, -0.1360],
        [ 0.5392, -0.1264],
        [ 0.5449, -0.1240],
        [ 0.5000, -0.1194],
        [ 0.5094, -0.1389],
        [ 0.5069, -0.1223],
        [ 0.4984, -0.1405],
        [ 0.5446, -0.1305]], device='cuda:0', grad_fn=<AddmmBackward0>)
tensor([0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0

Loss: 0.6929029822349548:  13%|██▏             | 14/105 [00:02<00:13,  6.61it/s]

tensor([[-0.0761,  0.6264],
        [-0.0860,  0.6085],
        [-0.0862,  0.6227],
        [-0.0861,  0.6335],
        [-0.0833,  0.6210],
        [-0.0648,  0.6417],
        [-0.0887,  0.6228],
        [-0.0694,  0.6407],
        [-0.0861,  0.6315],
        [-0.0709,  0.6220],
        [-0.0738,  0.6227],
        [-0.0800,  0.6248],
        [-0.0838,  0.6282],
        [-0.0642,  0.6474],
        [-0.0741,  0.6215],
        [-0.0866,  0.6223],
        [-0.0615,  0.6452],
        [-0.0750,  0.6232],
        [-0.0831,  0.6278],
        [-0.0858,  0.6292],
        [-0.0871,  0.5961],
        [-0.0867,  0.6134],
        [-0.0880,  0.6265],
        [-0.0690,  0.6263],
        [-0.0704,  0.6258],
        [-0.0824,  0.6460],
        [-0.0747,  0.6425],
        [-0.0781,  0.6225],
        [-0.0805,  0.6350],
        [-0.0865,  0.6218],
        [-0.0871,  0.6261],
        [-0.0673,  0.6419]], device='cuda:0', grad_fn=<AddmmBackward0>)
tensor([1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1




KeyboardInterrupt: 

In [4]:
features

NameError: name 'features' is not defined