In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torchvision import models, transforms
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
import gc
import os 

# clear cuda memory and collect garbage
gc.collect()
torch.cuda.empty_cache()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device: ',device)

x=np.load('scaled_spec_resampled_array.npy')
y=np.load('labels_array.npy')-1
x = x.reshape(x.shape[0], 1, x.shape[1], x.shape[2])

print(x.shape, y.shape)

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)

class MyDataset(Dataset):
    def __init__(self, x, y):
        self.x = torch.tensor(x, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
    def __len__(self):
        return len(self.x)
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

train_dataset = MyDataset(x_train, y_train)
test_dataset = MyDataset(x_test, y_test)


train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Device:  cuda:0
(1754, 1, 2048, 80) (1754,)


In [2]:
resnet34 = models.resnet34(pretrained=True)
resnet34.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
resnet34.fc = nn.Linear(in_features=512, out_features=6, bias=True)    


#check for checkpoint save and load resnet34.pth
if os.path.exists('resnet34.pth'):
    resnet34.load_state_dict(torch.load('resnet34.pth'))
    print('Model loaded')
else:
    print('No model found, loading pretrained model')


resnet34 = resnet34.to(device)

#freeze every layer except the first conv1 and last fc layers of resnet50
'''for name, param in resnet34.named_parameters():
    if name not in ['conv1.weight', 'fc.weight', 'fc.bias']:
        param.requires_grad = False
    else:
        param.requires_grad = True
'''



No model found, loading pretrained model


"for name, param in resnet34.named_parameters():\n    if name not in ['conv1.weight', 'fc.weight', 'fc.bias']:\n        param.requires_grad = False\n    else:\n        param.requires_grad = True\n"

In [3]:
def test(model, test_loader, criterion):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_loss = running_loss / len(test_loader)
    val_acc = correct / total
    return val_loss, val_acc

def train(model, train_loader, test_loader, criterion, optimizer, epochs):
    model.train()
    running_loss = 0.0
    epoch_bar = tqdm(range(epochs), position=0)
    for epoch in epoch_bar:
        batch_bar=tqdm(enumerate(train_loader, 0), total=len(train_loader), position=1, leave=False)
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            batch_bar.set_description('Train loss: %.3f' % (loss.item()))
        train_loss = running_loss / len(train_loader)
        epoch_bar.set_description('Train loss: %.3f' % train_loss)
        val_loss, val_acc = test(model, test_loader, criterion)
        print('Epoch: %d, Train Loss: %.3f, Val Loss: %.3f, Val Acc: %.3f' % (epoch, train_loss, val_loss, val_acc))
    return model

In [4]:
optimizer = optim.Adam(resnet34.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

resnet34=train(resnet34, train_loader,test_loader, criterion, optimizer, 100)

torch.save(resnet34.state_dict(), 'resnet34.pth')


Train loss: 1.206:   1%|          | 1/100 [00:24<40:59, 24.84s/it]

Epoch: 0, Train Loss: 1.206, Val Loss: 14.254, Val Acc: 0.279


Train loss: 29.489:   2%|▏         | 2/100 [00:44<35:29, 21.73s/it]

Epoch: 1, Train Loss: 29.489, Val Loss: 1.894, Val Acc: 0.154


Train loss: 31.285:   3%|▎         | 3/100 [01:06<35:33, 21.99s/it]

Epoch: 2, Train Loss: 31.285, Val Loss: 1.775, Val Acc: 0.154


Train loss: 33.058:   4%|▍         | 4/100 [01:29<35:53, 22.44s/it]

Epoch: 3, Train Loss: 33.058, Val Loss: 1.828, Val Acc: 0.185


Train loss: 34.853:   5%|▌         | 5/100 [01:52<35:48, 22.61s/it]

Epoch: 4, Train Loss: 34.853, Val Loss: 1.779, Val Acc: 0.154


Train loss: 36.619:   6%|▌         | 6/100 [02:15<35:39, 22.76s/it]

Epoch: 5, Train Loss: 36.619, Val Loss: 1.747, Val Acc: 0.225


Train loss: 38.383:   7%|▋         | 7/100 [02:38<35:25, 22.85s/it]

Epoch: 6, Train Loss: 38.383, Val Loss: 1.748, Val Acc: 0.225


Train loss: 40.103:   8%|▊         | 8/100 [03:01<35:08, 22.92s/it]

Epoch: 7, Train Loss: 40.103, Val Loss: 1.675, Val Acc: 0.276


Train loss: 41.906:   9%|▉         | 9/100 [03:24<34:49, 22.96s/it]

Epoch: 8, Train Loss: 41.906, Val Loss: 1.725, Val Acc: 0.222


Train loss: 43.628:  10%|█         | 10/100 [03:48<34:32, 23.02s/it]

Epoch: 9, Train Loss: 43.628, Val Loss: 1.685, Val Acc: 0.268


Train loss: 45.312:  11%|█         | 11/100 [04:11<34:13, 23.07s/it]

Epoch: 10, Train Loss: 45.312, Val Loss: 1.709, Val Acc: 0.276


Train loss: 46.979:  12%|█▏        | 12/100 [04:34<33:52, 23.09s/it]

Epoch: 11, Train Loss: 46.979, Val Loss: 1.571, Val Acc: 0.385


Train loss: 48.433:  13%|█▎        | 13/100 [04:57<33:30, 23.11s/it]

Epoch: 12, Train Loss: 48.433, Val Loss: 1.435, Val Acc: 0.407


Train loss: 49.737:  14%|█▍        | 14/100 [05:20<33:08, 23.13s/it]

Epoch: 13, Train Loss: 49.737, Val Loss: 1.285, Val Acc: 0.439


Train loss: 51.038:  15%|█▌        | 15/100 [05:43<32:46, 23.13s/it]

Epoch: 14, Train Loss: 51.038, Val Loss: 1.236, Val Acc: 0.487


Train loss: 52.344:  16%|█▌        | 16/100 [06:07<32:23, 23.13s/it]

Epoch: 15, Train Loss: 52.344, Val Loss: 1.221, Val Acc: 0.430


Train loss: 53.598:  17%|█▋        | 17/100 [06:30<32:01, 23.15s/it]

Epoch: 16, Train Loss: 53.598, Val Loss: 1.237, Val Acc: 0.467


Train loss: 54.817:  18%|█▊        | 18/100 [06:53<31:37, 23.15s/it]

Epoch: 17, Train Loss: 54.817, Val Loss: 1.142, Val Acc: 0.504


Train loss: 55.915:  19%|█▉        | 19/100 [07:16<31:13, 23.13s/it]

Epoch: 18, Train Loss: 55.915, Val Loss: 1.027, Val Acc: 0.558


Train loss: 57.159:  20%|██        | 20/100 [07:39<30:50, 23.13s/it]

Epoch: 19, Train Loss: 57.159, Val Loss: 1.008, Val Acc: 0.587


Train loss: 58.153:  21%|██        | 21/100 [08:02<30:26, 23.12s/it]

Epoch: 20, Train Loss: 58.153, Val Loss: 0.977, Val Acc: 0.561


Train loss: 59.028:  22%|██▏       | 22/100 [08:25<30:03, 23.13s/it]

Epoch: 21, Train Loss: 59.028, Val Loss: 0.767, Val Acc: 0.715


Train loss: 59.851:  23%|██▎       | 23/100 [08:49<29:42, 23.15s/it]

Epoch: 22, Train Loss: 59.851, Val Loss: 0.855, Val Acc: 0.672


Train loss: 60.600:  24%|██▍       | 24/100 [09:12<29:18, 23.14s/it]

Epoch: 23, Train Loss: 60.600, Val Loss: 0.627, Val Acc: 0.772


Train loss: 61.251:  25%|██▌       | 25/100 [09:35<28:49, 23.06s/it]

Epoch: 24, Train Loss: 61.251, Val Loss: 0.511, Val Acc: 0.781


Train loss: 61.871:  26%|██▌       | 26/100 [09:55<27:34, 22.36s/it]

Epoch: 25, Train Loss: 61.871, Val Loss: 0.508, Val Acc: 0.786


Train loss: 62.475:  27%|██▋       | 27/100 [10:17<26:52, 22.08s/it]

Epoch: 26, Train Loss: 62.475, Val Loss: 0.567, Val Acc: 0.803


Train loss: 63.016:  28%|██▊       | 28/100 [10:40<26:51, 22.38s/it]

Epoch: 27, Train Loss: 63.016, Val Loss: 0.455, Val Acc: 0.815


Train loss: 63.432:  29%|██▉       | 29/100 [11:03<26:42, 22.57s/it]

Epoch: 28, Train Loss: 63.432, Val Loss: 0.552, Val Acc: 0.789


Train loss: 63.896:  30%|███       | 30/100 [11:26<26:30, 22.72s/it]

Epoch: 29, Train Loss: 63.896, Val Loss: 0.463, Val Acc: 0.821


Train loss: 64.358:  31%|███       | 31/100 [11:49<26:15, 22.83s/it]

Epoch: 30, Train Loss: 64.358, Val Loss: 0.380, Val Acc: 0.858


Train loss: 64.741:  32%|███▏      | 32/100 [12:12<25:56, 22.90s/it]

Epoch: 31, Train Loss: 64.741, Val Loss: 0.322, Val Acc: 0.869


Train loss: 65.086:  33%|███▎      | 33/100 [12:35<25:33, 22.89s/it]

Epoch: 32, Train Loss: 65.086, Val Loss: 0.332, Val Acc: 0.855


Train loss: 65.388:  34%|███▍      | 34/100 [12:58<25:10, 22.89s/it]

Epoch: 33, Train Loss: 65.388, Val Loss: 0.315, Val Acc: 0.860


Train loss: 65.707:  35%|███▌      | 35/100 [13:21<24:48, 22.90s/it]

Epoch: 34, Train Loss: 65.707, Val Loss: 0.306, Val Acc: 0.858


Train loss: 66.018:  36%|███▌      | 36/100 [13:44<24:25, 22.90s/it]

Epoch: 35, Train Loss: 66.018, Val Loss: 0.251, Val Acc: 0.897


Train loss: 66.302:  37%|███▋      | 37/100 [14:06<24:02, 22.90s/it]

Epoch: 36, Train Loss: 66.302, Val Loss: 0.323, Val Acc: 0.872


Train loss: 66.556:  38%|███▊      | 38/100 [14:29<23:41, 22.93s/it]

Epoch: 37, Train Loss: 66.556, Val Loss: 0.287, Val Acc: 0.872


Train loss: 66.781:  39%|███▉      | 39/100 [14:52<23:20, 22.95s/it]

Epoch: 38, Train Loss: 66.781, Val Loss: 0.230, Val Acc: 0.906


Train loss: 66.985:  40%|████      | 40/100 [15:16<23:00, 23.00s/it]

Epoch: 39, Train Loss: 66.985, Val Loss: 0.314, Val Acc: 0.855


Train loss: 67.212:  41%|████      | 41/100 [15:39<22:37, 23.01s/it]

Epoch: 40, Train Loss: 67.212, Val Loss: 0.321, Val Acc: 0.886


Train loss: 67.428:  42%|████▏     | 42/100 [16:01<21:56, 22.69s/it]

Epoch: 41, Train Loss: 67.428, Val Loss: 0.434, Val Acc: 0.823


Train loss: 67.670:  43%|████▎     | 43/100 [16:21<20:58, 22.08s/it]

Epoch: 42, Train Loss: 67.670, Val Loss: 0.302, Val Acc: 0.880


Train loss: 67.888:  44%|████▍     | 44/100 [16:43<20:27, 21.92s/it]

Epoch: 43, Train Loss: 67.888, Val Loss: 0.267, Val Acc: 0.900


Train loss: 68.066:  45%|████▌     | 45/100 [17:04<20:00, 21.82s/it]

Epoch: 44, Train Loss: 68.066, Val Loss: 0.244, Val Acc: 0.900


Train loss: 68.257:  46%|████▌     | 46/100 [17:27<19:46, 21.98s/it]

Epoch: 45, Train Loss: 68.257, Val Loss: 0.296, Val Acc: 0.866


Train loss: 68.473:  47%|████▋     | 47/100 [17:48<19:17, 21.84s/it]

Epoch: 46, Train Loss: 68.473, Val Loss: 0.258, Val Acc: 0.897


Train loss: 68.793:  48%|████▊     | 48/100 [18:11<19:03, 21.99s/it]

Epoch: 47, Train Loss: 68.793, Val Loss: 0.372, Val Acc: 0.855


Train loss: 69.081:  49%|████▉     | 49/100 [18:31<18:23, 21.64s/it]

Epoch: 48, Train Loss: 69.081, Val Loss: 0.278, Val Acc: 0.886


Train loss: 69.298:  50%|█████     | 50/100 [18:53<18:00, 21.60s/it]

Epoch: 49, Train Loss: 69.298, Val Loss: 0.492, Val Acc: 0.815


Train loss: 69.505:  51%|█████     | 51/100 [19:13<17:18, 21.19s/it]

Epoch: 50, Train Loss: 69.505, Val Loss: 0.302, Val Acc: 0.866


Train loss: 69.677:  52%|█████▏    | 52/100 [19:33<16:44, 20.93s/it]

Epoch: 51, Train Loss: 69.677, Val Loss: 0.288, Val Acc: 0.900


Train loss: 69.798:  53%|█████▎    | 53/100 [19:54<16:15, 20.75s/it]

Epoch: 52, Train Loss: 69.798, Val Loss: 0.337, Val Acc: 0.903


Train loss: 69.905:  54%|█████▍    | 54/100 [20:15<15:57, 20.81s/it]

Epoch: 53, Train Loss: 69.905, Val Loss: 0.308, Val Acc: 0.883


Train loss: 70.019:  55%|█████▌    | 55/100 [20:35<15:30, 20.68s/it]

Epoch: 54, Train Loss: 70.019, Val Loss: 0.248, Val Acc: 0.906


Train loss: 70.163:  56%|█████▌    | 56/100 [20:55<15:04, 20.56s/it]

Epoch: 55, Train Loss: 70.163, Val Loss: 0.302, Val Acc: 0.906


Train loss: 70.355:  57%|█████▋    | 57/100 [21:16<14:40, 20.48s/it]

Epoch: 56, Train Loss: 70.355, Val Loss: 0.338, Val Acc: 0.886


Train loss: 70.544:  58%|█████▊    | 58/100 [21:36<14:17, 20.42s/it]

Epoch: 57, Train Loss: 70.544, Val Loss: 0.400, Val Acc: 0.838


Train loss: 70.735:  59%|█████▉    | 59/100 [21:56<13:55, 20.37s/it]

Epoch: 58, Train Loss: 70.735, Val Loss: 0.201, Val Acc: 0.917


Train loss: 70.867:  60%|██████    | 60/100 [22:16<13:33, 20.34s/it]

Epoch: 59, Train Loss: 70.867, Val Loss: 0.249, Val Acc: 0.906


Train loss: 71.004:  61%|██████    | 61/100 [22:37<13:12, 20.32s/it]

Epoch: 60, Train Loss: 71.004, Val Loss: 0.319, Val Acc: 0.880


Train loss: 71.134:  62%|██████▏   | 62/100 [22:57<12:51, 20.31s/it]

Epoch: 61, Train Loss: 71.134, Val Loss: 0.323, Val Acc: 0.886


Train loss: 71.240:  63%|██████▎   | 63/100 [23:17<12:31, 20.30s/it]

Epoch: 62, Train Loss: 71.240, Val Loss: 0.270, Val Acc: 0.906


Train loss: 71.335:  64%|██████▍   | 64/100 [23:38<12:10, 20.29s/it]

Epoch: 63, Train Loss: 71.335, Val Loss: 0.220, Val Acc: 0.934


Train loss: 71.422:  65%|██████▌   | 65/100 [23:58<11:50, 20.29s/it]

Epoch: 64, Train Loss: 71.422, Val Loss: 0.257, Val Acc: 0.895


Train loss: 71.526:  66%|██████▌   | 66/100 [24:18<11:29, 20.28s/it]

Epoch: 65, Train Loss: 71.526, Val Loss: 0.293, Val Acc: 0.892


Train loss: 71.687:  67%|██████▋   | 67/100 [24:38<11:09, 20.29s/it]

Epoch: 66, Train Loss: 71.687, Val Loss: 0.276, Val Acc: 0.895


Train loss: 71.872:  68%|██████▊   | 68/100 [24:59<10:49, 20.28s/it]

Epoch: 67, Train Loss: 71.872, Val Loss: 0.292, Val Acc: 0.909


Train loss: 72.024:  69%|██████▉   | 69/100 [25:19<10:28, 20.29s/it]

Epoch: 68, Train Loss: 72.024, Val Loss: 0.227, Val Acc: 0.912


Train loss: 72.149:  70%|███████   | 70/100 [25:39<10:08, 20.28s/it]

Epoch: 69, Train Loss: 72.149, Val Loss: 0.296, Val Acc: 0.900


Train loss: 72.259:  71%|███████   | 71/100 [26:00<09:48, 20.29s/it]

Epoch: 70, Train Loss: 72.259, Val Loss: 0.276, Val Acc: 0.895


Train loss: 72.331:  72%|███████▏  | 72/100 [26:20<09:27, 20.28s/it]

Epoch: 71, Train Loss: 72.331, Val Loss: 0.222, Val Acc: 0.926


Train loss: 72.402:  73%|███████▎  | 73/100 [26:40<09:07, 20.28s/it]

Epoch: 72, Train Loss: 72.402, Val Loss: 0.251, Val Acc: 0.912


Train loss: 72.462:  74%|███████▍  | 74/100 [27:00<08:47, 20.28s/it]

Epoch: 73, Train Loss: 72.462, Val Loss: 0.310, Val Acc: 0.906


Train loss: 72.553:  75%|███████▌  | 75/100 [27:21<08:26, 20.28s/it]

Epoch: 74, Train Loss: 72.553, Val Loss: 0.283, Val Acc: 0.892


Train loss: 72.650:  76%|███████▌  | 76/100 [27:41<08:06, 20.28s/it]

Epoch: 75, Train Loss: 72.650, Val Loss: 0.477, Val Acc: 0.846


Train loss: 72.745:  77%|███████▋  | 77/100 [28:01<07:46, 20.28s/it]

Epoch: 76, Train Loss: 72.745, Val Loss: 0.312, Val Acc: 0.912


Train loss: 72.795:  78%|███████▊  | 78/100 [28:22<07:26, 20.27s/it]

Epoch: 77, Train Loss: 72.795, Val Loss: 0.327, Val Acc: 0.903


Train loss: 72.858:  79%|███████▉  | 79/100 [28:42<07:05, 20.28s/it]

Epoch: 78, Train Loss: 72.858, Val Loss: 0.311, Val Acc: 0.909


Train loss: 72.903:  80%|████████  | 80/100 [29:02<06:45, 20.28s/it]

Epoch: 79, Train Loss: 72.903, Val Loss: 0.378, Val Acc: 0.900


Train loss: 72.997:  81%|████████  | 81/100 [29:22<06:25, 20.28s/it]

Epoch: 80, Train Loss: 72.997, Val Loss: 0.507, Val Acc: 0.838


Train loss: 73.123:  82%|████████▏ | 82/100 [29:43<06:04, 20.27s/it]

Epoch: 81, Train Loss: 73.123, Val Loss: 0.252, Val Acc: 0.903


Train loss: 73.209:  83%|████████▎ | 83/100 [30:03<05:44, 20.28s/it]

Epoch: 82, Train Loss: 73.209, Val Loss: 0.285, Val Acc: 0.900


Train loss: 73.293:  84%|████████▍ | 84/100 [30:23<05:24, 20.27s/it]

Epoch: 83, Train Loss: 73.293, Val Loss: 0.335, Val Acc: 0.892


Train loss: 73.426:  85%|████████▌ | 85/100 [30:43<05:03, 20.25s/it]

Epoch: 84, Train Loss: 73.426, Val Loss: 0.455, Val Acc: 0.852


Train loss: 73.542:  86%|████████▌ | 86/100 [31:04<04:43, 20.23s/it]

Epoch: 85, Train Loss: 73.542, Val Loss: 0.279, Val Acc: 0.903


Train loss: 73.612:  87%|████████▋ | 87/100 [31:24<04:22, 20.22s/it]

Epoch: 86, Train Loss: 73.612, Val Loss: 0.394, Val Acc: 0.875


Train loss: 73.824:  88%|████████▊ | 88/100 [31:44<04:02, 20.21s/it]

Epoch: 87, Train Loss: 73.824, Val Loss: 0.361, Val Acc: 0.869


Train loss: 73.964:  89%|████████▉ | 89/100 [32:06<03:49, 20.88s/it]

Epoch: 88, Train Loss: 73.964, Val Loss: 0.264, Val Acc: 0.883


Train loss: 74.041:  90%|█████████ | 90/100 [32:28<03:30, 21.09s/it]

Epoch: 89, Train Loss: 74.041, Val Loss: 0.463, Val Acc: 0.860


Train loss: 74.169:  91%|█████████ | 91/100 [32:49<03:11, 21.22s/it]

Epoch: 90, Train Loss: 74.169, Val Loss: 0.565, Val Acc: 0.855


Train loss: 74.274:  92%|█████████▏| 92/100 [33:11<02:50, 21.32s/it]

Epoch: 91, Train Loss: 74.274, Val Loss: 0.315, Val Acc: 0.903


Train loss: 74.340:  93%|█████████▎| 93/100 [33:33<02:30, 21.43s/it]

Epoch: 92, Train Loss: 74.340, Val Loss: 0.333, Val Acc: 0.889


Train loss: 74.388:  94%|█████████▍| 94/100 [33:56<02:11, 21.95s/it]

Epoch: 93, Train Loss: 74.388, Val Loss: 0.376, Val Acc: 0.900


Train loss: 74.423:  95%|█████████▌| 95/100 [34:19<01:51, 22.32s/it]

Epoch: 94, Train Loss: 74.423, Val Loss: 0.276, Val Acc: 0.915


Train loss: 74.432:  96%|█████████▌| 96/100 [34:42<01:29, 22.50s/it]

Epoch: 95, Train Loss: 74.432, Val Loss: 0.535, Val Acc: 0.900


Train loss: 74.458:  97%|█████████▋| 97/100 [35:05<01:08, 22.72s/it]

Epoch: 96, Train Loss: 74.458, Val Loss: 0.600, Val Acc: 0.889


Train loss: 74.499:  98%|█████████▊| 98/100 [35:28<00:45, 22.78s/it]

Epoch: 97, Train Loss: 74.499, Val Loss: 0.402, Val Acc: 0.883


Train loss: 74.524:  99%|█████████▉| 99/100 [35:51<00:22, 22.91s/it]

Epoch: 98, Train Loss: 74.524, Val Loss: 0.394, Val Acc: 0.895


Train loss: 74.532: 100%|██████████| 100/100 [36:14<00:00, 21.75s/it]


Epoch: 99, Train Loss: 74.532, Val Loss: 0.623, Val Acc: 0.900
