In [1]:
import torch
from torch import nn as nn
import os
import numpy as np
from tqdm.notebook import tqdm

In [2]:
def setup_torch(random_seed, use_gpu, gpu_number=0):
    torch.manual_seed(random_seed)
    torch.set_num_threads(8)
    if use_gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_number)
        torch.cuda.manual_seed(random_seed)

In [3]:
setup_torch(0, 1, 0)

In [4]:
import json
import glob
import matplotlib.pyplot as plt

In [103]:
# loading our data
with open('/home/colin/features_cov/covid_feat_3bcnioqn_fold_0.json') as fp:
    all_data = json.load(fp)

In [104]:
train_x = []
val_x = []
test_x = []
train_y = []
val_y = []
test_y = []
all_x = []
for patient_id, data in all_data.items():
    if 'val_features' in data:
        val_x.append(data['val_features'])
        val_y.append(data['val_label'])
    elif 'test_features' in data:
        test_x.append(data['test_features'])
        test_y.append(data['test_label'])
    else:
        all_x += data['train_features']
        train_x.append(data['train_features'])
        train_y.append(data['train_label'])

In [155]:
class ClassifierMIL(nn.Module):
    
    def __init__(self):
        super(ClassifierMIL, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )
#         self.classifier = nn.Linear(1024, 1)
        
    def forward(self, x):
        # x is BxNxF
        return torch.sigmoid(torch.mean(self.classifier(x)))

In [156]:
model = ClassifierMIL()
model.cuda()
optim = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fun = nn.BCELoss()
accum = 4

In [152]:
optim = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
all_accs = []
all_losses = []
for epoch in range(1000):
    accs = []
    losses =[]
    p_bar = tqdm(total=len(train_y))
    counter = 0
    ac_loss = 0
    inds = np.arange(0, len(train_x))
    np.random.shuffle(inds)
    for ind in inds:
        x_data = train_x[ind]
        y_data = train_y[ind]
        x_in = torch.tensor(x_data).view(1, -1, 1024).cuda()
        x_in = (x_in - 0.08)/ 0.021013880255015563
        y = torch.tensor(y_data).cuda()
        out = model(x_in)
        loss = loss_fun(out, y.float())
        ac_loss += loss
        acc = int(round(out.item()) == y_data)
        accs.append(acc)
        losses.append(float(loss))
        p_bar.set_description(f"loss = {np.mean(losses):3f}, acc = {np.mean(accs):.2f}")
        p_bar.update()
        counter += 1
        if counter % accum == 0:
            ac_loss /= counter
            optim.zero_grad()
            ac_loss.backward()
            optim.step()
            counter = 0
            ac_loss = 0
            
    if counter != 0:
        ac_loss /= counter
        optim.zero_grad()
        ac_loss.backward()
        optim.step()
        counter = 0
        ac_loss = 0
    print(f"Epoch {epoch}, Loss = {np.mean(losses)}, Acc = {np.mean(accs)}")
    all_accs.append(np.mean(accs))
    all_losses.append(np.mean(losses))

HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 0, Loss = 0.9887542265757642, Acc = 0.5833333333333334


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 1, Loss = 0.637896426810095, Acc = 0.6538461538461539


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 2, Loss = 0.5873449708884343, Acc = 0.6923076923076923


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 3, Loss = 0.5733714470257744, Acc = 0.7051282051282052




HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 4, Loss = 0.5583826937318708, Acc = 0.7051282051282052


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 5, Loss = 0.5652427202066741, Acc = 0.7564102564102564


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 6, Loss = 0.5242663600410407, Acc = 0.7628205128205128


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 7, Loss = 0.5200805358755856, Acc = 0.7692307692307693


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))




Epoch 8, Loss = 0.5701341324378378, Acc = 0.7307692307692307


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 9, Loss = 0.5023504417532911, Acc = 0.8076923076923077


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 10, Loss = 0.5115085247462281, Acc = 0.7756410256410257


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 11, Loss = 0.5031358875477543, Acc = 0.782051282051282


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 12, Loss = 0.5085810222961487, Acc = 0.7692307692307693





HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 13, Loss = 0.5182853420876946, Acc = 0.7243589743589743


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 14, Loss = 0.49289088339831394, Acc = 0.7692307692307693


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 15, Loss = 0.5127229305903594, Acc = 0.7692307692307693


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 16, Loss = 0.5048644675347859, Acc = 0.7948717948717948




HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 17, Loss = 0.507615998678375, Acc = 0.7564102564102564


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 18, Loss = 0.5356619555103139, Acc = 0.7435897435897436


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 19, Loss = 0.5055592930353342, Acc = 0.7628205128205128


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 20, Loss = 0.5241340733825778, Acc = 0.7692307692307693


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))




Epoch 21, Loss = 0.4991728184753671, Acc = 0.7756410256410257


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 22, Loss = 0.5028930699604993, Acc = 0.7435897435897436


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 23, Loss = 0.4987212084293461, Acc = 0.782051282051282


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 24, Loss = 0.5074091033714537, Acc = 0.7371794871794872


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 25, Loss = 0.5111910427983205, Acc = 0.7692307692307693





HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 26, Loss = 0.5186086431285963, Acc = 0.7628205128205128


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 27, Loss = 0.5004769305397088, Acc = 0.782051282051282


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 28, Loss = 0.49171834083823246, Acc = 0.782051282051282


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 29, Loss = 0.4936743542381849, Acc = 0.7948717948717948




HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 30, Loss = 0.49102047627541023, Acc = 0.7564102564102564


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 31, Loss = 0.5261360354981242, Acc = 0.7564102564102564


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 32, Loss = 0.511610633192154, Acc = 0.7756410256410257


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 33, Loss = 0.5244535881315525, Acc = 0.75


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))




Epoch 34, Loss = 0.5071151193279104, Acc = 0.7628205128205128


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 35, Loss = 0.49749486695658657, Acc = 0.7756410256410257


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 36, Loss = 0.5076771774560583, Acc = 0.7628205128205128


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 37, Loss = 0.4936494619364683, Acc = 0.7756410256410257


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 38, Loss = 0.5010702483167944, Acc = 0.7692307692307693





HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 39, Loss = 0.5148174601535385, Acc = 0.7884615384615384


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 40, Loss = 0.5032047061189914, Acc = 0.8012820512820513


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 41, Loss = 0.4923587709945889, Acc = 0.7756410256410257


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 42, Loss = 0.49390209392339995, Acc = 0.7564102564102564




















HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 43, Loss = 0.4923408144699911, Acc = 0.8012820512820513


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 44, Loss = 0.5028247721702195, Acc = 0.782051282051282


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 45, Loss = 0.4964408792632942, Acc = 0.7564102564102564


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 46, Loss = 0.4961517222109251, Acc = 0.782051282051282


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 47, Loss = 0.5001494250859534, Acc = 0.7756410256410257





HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 48, Loss = 0.49028497965460144, Acc = 0.7628205128205128


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 49, Loss = 0.48819248469916576, Acc = 0.7692307692307693


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 50, Loss = 0.49176116076668197, Acc = 0.7884615384615384


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 51, Loss = 0.4916513583022886, Acc = 0.7628205128205128




HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 52, Loss = 0.48913447859320935, Acc = 0.782051282051282


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 53, Loss = 0.49472555189874046, Acc = 0.7692307692307693


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 54, Loss = 0.4990802876925908, Acc = 0.7884615384615384


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 55, Loss = 0.49152499240023106, Acc = 0.7692307692307693


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 56, Loss = 0.49020618939091665, Acc = 0.782051282051282





HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 57, Loss = 0.5029402528740227, Acc = 0.8076923076923077


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 58, Loss = 0.4991597637343101, Acc = 0.7756410256410257


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 59, Loss = 0.49774005612692773, Acc = 0.7435897435897436


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 60, Loss = 0.48727404453212586, Acc = 0.75




HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 61, Loss = 0.4915224551118743, Acc = 0.782051282051282


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 62, Loss = 0.4889747086551954, Acc = 0.7692307692307693


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 63, Loss = 0.5312102480645243, Acc = 0.7435897435897436


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 64, Loss = 0.49091147657376355, Acc = 0.7564102564102564


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))




Epoch 65, Loss = 0.4899186033901806, Acc = 0.7756410256410257


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 66, Loss = 0.48844225605940805, Acc = 0.7628205128205128


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 67, Loss = 0.4879612450104338, Acc = 0.7692307692307693


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 68, Loss = 0.489892484346339, Acc = 0.782051282051282


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 69, Loss = 0.4924318903269103, Acc = 0.7884615384615384





HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 70, Loss = 0.4936568519497553, Acc = 0.782051282051282


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 71, Loss = 0.4922739180381028, Acc = 0.7564102564102564


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 72, Loss = 0.4911131987084026, Acc = 0.7756410256410257


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 73, Loss = 0.48817504508546194, Acc = 0.7884615384615384




HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 74, Loss = 0.48815225914586335, Acc = 0.7564102564102564


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 75, Loss = 0.4850635104484331, Acc = 0.7628205128205128


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 76, Loss = 0.48093847378652593, Acc = 0.7756410256410257


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 77, Loss = 0.49955180373404007, Acc = 0.7692307692307693


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))




Epoch 78, Loss = 0.4981074550021917, Acc = 0.7307692307692307


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 79, Loss = 0.48316724180581216, Acc = 0.7884615384615384


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 80, Loss = 0.5014060393083267, Acc = 0.7564102564102564


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 81, Loss = 0.48871883214600625, Acc = 0.782051282051282


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 82, Loss = 0.4890073939951925, Acc = 0.7628205128205128





HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 83, Loss = 0.4937124237177882, Acc = 0.7628205128205128


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 84, Loss = 0.5027631376265512, Acc = 0.782051282051282


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 85, Loss = 0.4897152808635161, Acc = 0.7756410256410257


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 86, Loss = 0.4836509358585597, Acc = 0.7692307692307693




HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 87, Loss = 0.48318568254874733, Acc = 0.7628205128205128


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 88, Loss = 0.48085739845201037, Acc = 0.782051282051282


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 89, Loss = 0.5005350948571831, Acc = 0.7692307692307693


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 90, Loss = 0.4901737698461287, Acc = 0.782051282051282


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))




Epoch 91, Loss = 0.48635789179108824, Acc = 0.7884615384615384


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 92, Loss = 0.4856351229738301, Acc = 0.7948717948717948


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 93, Loss = 0.48477302856134397, Acc = 0.7692307692307693


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 94, Loss = 0.4888069105785316, Acc = 0.782051282051282


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 95, Loss = 0.4789616402928144, Acc = 0.782051282051282
























HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 96, Loss = 0.47913604791168696, Acc = 0.7884615384615384


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 97, Loss = 0.4807984949789571, Acc = 0.7884615384615384


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 98, Loss = 0.48037882999088677, Acc = 0.7628205128205128


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

Epoch 99, Loss = 0.47920654249416356, Acc = 0.7756410256410257


HBox(children=(FloatProgress(value=0.0, max=156.0), HTML(value='')))

In [132]:
out

tensor(0.0793, device='cuda:0', grad_fn=<SigmoidBackward>)

In [None]:
plt.plot(all_accs)

In [None]:
plt.plot(all_losses[100:])

In [141]:
inds = np.arange(0, len(val_x))
np.random.shuffle(inds)
accs = []
for ind in inds:
    x_data = val_x[ind]
    y_data = val_y[ind]
    x_in = torch.tensor(x_data).view(1, -1, 1024).cuda()
    x_in = (x_in - 0.08)/ 0.021013880255015563
    y = torch.tensor(y_data).cuda()
    out = model(x_in)
    loss = loss_fun(out, y.float())
    ac_loss += loss
    acc = int(round(out.item()) == y_data)
    accs.append(acc)

In [142]:
np.mean(accs)

0.7948717948717948

In [32]:
torch.tensor(train_x[0:8]).shape

ValueError: expected sequence of length 142 at dim 1 (got 129)

In [33]:
out = model(torch.tensor(train_x[0:1]))

In [35]:
out

tensor([[-0.4763]], grad_fn=<AddmmBackward>)

In [9]:
learned_vec = torch.zeros((1,1024))

In [16]:
r = learned_vec.view(1, 1, -1).repeat(3,1,1)

In [17]:
r.shape

torch.Size([3, 1, 1024])

In [20]:
torch.cat([r,r], axis=1).shape

torch.Size([3, 2, 1024])