In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torchvision import models, transforms
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
import gc #garbage collector

# clear cuda memory and collect garbage
gc.collect()
torch.cuda.empty_cache()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device: ',device)

x=np.load('scaled_spec_resampled_array.npy')
y=np.load('labels_array.npy')-1
x = x.reshape(x.shape[0], 1, x.shape[1], x.shape[2])

print(x.shape, y.shape)

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)

class MyDataset(Dataset):
    def __init__(self, x, y):
        self.x = torch.tensor(x, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
    def __len__(self):
        return len(self.x)
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

train_dataset = MyDataset(x_train, y_train)
test_dataset = MyDataset(x_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Device:  cuda:0
(1754, 1, 2048, 80) (1754,)


In [2]:
resnet18 = models.resnet18(pretrained=True)
resnet18.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
resnet18.fc = nn.Linear(in_features=512, out_features=6, bias=True)    

resnet18 = resnet18.to(device)

#freeze every layer except the first conv1 layer and the last fc layer, to adapt the model to our data
for name, param in resnet18.named_parameters():
    if name not in ['conv1.weight', 'fc.weight', 'fc.bias']:
        param.requires_grad = False
    else:
        param.requires_grad = True



In [3]:
def test(model, test_loader, criterion):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_loss = running_loss / len(test_loader)
    val_acc = correct / total
    return val_loss, val_acc

def train(model, train_loader, test_loader, criterion, optimizer, epochs):
    model.train()
    running_loss = 0.0
    epoch_bar = tqdm(range(epochs), position=0)
    for epoch in epoch_bar:
        batch_bar=tqdm(enumerate(train_loader, 0), total=len(train_loader), position=1, leave=False)
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            batch_bar.set_description('Train loss: %.3f' % (loss.item()))
        train_loss = running_loss / len(train_loader)
        epoch_bar.set_description('Train loss: %.3f' % train_loss)
        val_loss, val_acc = test(model, test_loader, criterion)
        print('Epoch: %d, Train Loss: %.3f, Val Loss: %.3f, Val Acc: %.3f' % (epoch, train_loss, val_loss, val_acc))
    return model


In [4]:
optimizer = optim.Adam(resnet18.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

resnet18=train(resnet18, train_loader,test_loader, criterion, optimizer, 100)

torch.save(resnet18.state_dict(), 'resnet18.pth')

Train loss: 1.678:   1%|          | 1/100 [00:25<41:21, 25.07s/it]

Epoch: 0, Train Loss: 1.678, Val Loss: 1.766, Val Acc: 0.225


Train loss: 3.183:   2%|▏         | 2/100 [00:33<25:01, 15.32s/it]

Epoch: 1, Train Loss: 3.183, Val Loss: 1.260, Val Acc: 0.462


Train loss: 4.450:   3%|▎         | 3/100 [00:42<19:51, 12.28s/it]

Epoch: 2, Train Loss: 4.450, Val Loss: 1.166, Val Acc: 0.447


Train loss: 5.692:   4%|▍         | 4/100 [00:50<17:21, 10.85s/it]

Epoch: 3, Train Loss: 5.692, Val Loss: 1.271, Val Acc: 0.467


Train loss: 6.923:   5%|▌         | 5/100 [00:59<15:55, 10.06s/it]

Epoch: 4, Train Loss: 6.923, Val Loss: 1.149, Val Acc: 0.484


Train loss: 8.125:   6%|▌         | 6/100 [01:08<15:03,  9.61s/it]

Epoch: 5, Train Loss: 8.125, Val Loss: 1.144, Val Acc: 0.507


Train loss: 9.294:   7%|▋         | 7/100 [01:17<14:34,  9.41s/it]

Epoch: 6, Train Loss: 9.294, Val Loss: 1.142, Val Acc: 0.484


Train loss: 10.460:   8%|▊         | 8/100 [01:27<14:35,  9.52s/it]

Epoch: 7, Train Loss: 10.460, Val Loss: 1.112, Val Acc: 0.507


Train loss: 11.604:   9%|▉         | 9/100 [01:36<14:35,  9.62s/it]

Epoch: 8, Train Loss: 11.604, Val Loss: 1.061, Val Acc: 0.613


Train loss: 12.709:  10%|█         | 10/100 [01:46<14:32,  9.69s/it]

Epoch: 9, Train Loss: 12.709, Val Loss: 1.064, Val Acc: 0.570


Train loss: 13.811:  11%|█         | 11/100 [01:56<14:28,  9.76s/it]

Epoch: 10, Train Loss: 13.811, Val Loss: 1.088, Val Acc: 0.564


Train loss: 14.891:  12%|█▏        | 12/100 [02:06<14:23,  9.82s/it]

Epoch: 11, Train Loss: 14.891, Val Loss: 1.178, Val Acc: 0.507


Train loss: 15.949:  13%|█▎        | 13/100 [02:16<14:18,  9.86s/it]

Epoch: 12, Train Loss: 15.949, Val Loss: 1.002, Val Acc: 0.575


Train loss: 16.981:  14%|█▍        | 14/100 [02:26<14:09,  9.88s/it]

Epoch: 13, Train Loss: 16.981, Val Loss: 1.005, Val Acc: 0.618


Train loss: 18.009:  15%|█▌        | 15/100 [02:36<14:01,  9.90s/it]

Epoch: 14, Train Loss: 18.009, Val Loss: 0.987, Val Acc: 0.598


Train loss: 19.027:  16%|█▌        | 16/100 [02:46<13:53,  9.93s/it]

Epoch: 15, Train Loss: 19.027, Val Loss: 0.986, Val Acc: 0.601


Train loss: 20.008:  17%|█▋        | 17/100 [02:56<13:43,  9.92s/it]

Epoch: 16, Train Loss: 20.008, Val Loss: 0.994, Val Acc: 0.601


Train loss: 20.991:  18%|█▊        | 18/100 [03:06<13:36,  9.95s/it]

Epoch: 17, Train Loss: 20.991, Val Loss: 0.948, Val Acc: 0.650


Train loss: 21.943:  19%|█▉        | 19/100 [03:16<13:25,  9.94s/it]

Epoch: 18, Train Loss: 21.943, Val Loss: 0.934, Val Acc: 0.647


Train loss: 22.876:  20%|██        | 20/100 [03:26<13:15,  9.94s/it]

Epoch: 19, Train Loss: 22.876, Val Loss: 0.933, Val Acc: 0.632


Train loss: 23.780:  21%|██        | 21/100 [03:36<13:04,  9.93s/it]

Epoch: 20, Train Loss: 23.780, Val Loss: 0.906, Val Acc: 0.678


Train loss: 24.674:  22%|██▏       | 22/100 [03:46<12:55,  9.94s/it]

Epoch: 21, Train Loss: 24.674, Val Loss: 0.898, Val Acc: 0.675


Train loss: 25.551:  23%|██▎       | 23/100 [03:56<12:45,  9.94s/it]

Epoch: 22, Train Loss: 25.551, Val Loss: 0.913, Val Acc: 0.635


Train loss: 26.444:  24%|██▍       | 24/100 [04:05<12:34,  9.93s/it]

Epoch: 23, Train Loss: 26.444, Val Loss: 0.850, Val Acc: 0.709


Train loss: 27.298:  25%|██▌       | 25/100 [04:15<12:24,  9.93s/it]

Epoch: 24, Train Loss: 27.298, Val Loss: 0.834, Val Acc: 0.724


Train loss: 28.126:  26%|██▌       | 26/100 [04:25<12:14,  9.93s/it]

Epoch: 25, Train Loss: 28.126, Val Loss: 0.985, Val Acc: 0.610


Train loss: 28.994:  27%|██▋       | 27/100 [04:35<12:04,  9.92s/it]

Epoch: 26, Train Loss: 28.994, Val Loss: 0.874, Val Acc: 0.695


Train loss: 29.810:  28%|██▊       | 28/100 [04:45<11:54,  9.93s/it]

Epoch: 27, Train Loss: 29.810, Val Loss: 0.790, Val Acc: 0.755


Train loss: 30.593:  29%|██▉       | 29/100 [04:55<11:44,  9.92s/it]

Epoch: 28, Train Loss: 30.593, Val Loss: 0.780, Val Acc: 0.735


Train loss: 31.419:  30%|███       | 30/100 [05:05<11:34,  9.92s/it]

Epoch: 29, Train Loss: 31.419, Val Loss: 0.847, Val Acc: 0.684


Train loss: 32.247:  31%|███       | 31/100 [05:15<11:24,  9.92s/it]

Epoch: 30, Train Loss: 32.247, Val Loss: 0.820, Val Acc: 0.721


Train loss: 33.039:  32%|███▏      | 32/100 [05:25<11:13,  9.91s/it]

Epoch: 31, Train Loss: 33.039, Val Loss: 0.766, Val Acc: 0.726


Train loss: 33.802:  33%|███▎      | 33/100 [05:35<11:04,  9.92s/it]

Epoch: 32, Train Loss: 33.802, Val Loss: 0.863, Val Acc: 0.712


Train loss: 34.571:  34%|███▍      | 34/100 [05:45<10:54,  9.92s/it]

Epoch: 33, Train Loss: 34.571, Val Loss: 0.783, Val Acc: 0.707


Train loss: 35.327:  35%|███▌      | 35/100 [05:55<10:45,  9.92s/it]

Epoch: 34, Train Loss: 35.327, Val Loss: 0.756, Val Acc: 0.769


Train loss: 36.090:  36%|███▌      | 36/100 [06:04<10:34,  9.91s/it]

Epoch: 35, Train Loss: 36.090, Val Loss: 0.731, Val Acc: 0.752


Train loss: 36.839:  37%|███▋      | 37/100 [06:14<10:24,  9.91s/it]

Epoch: 36, Train Loss: 36.839, Val Loss: 0.775, Val Acc: 0.718


Train loss: 37.562:  38%|███▊      | 38/100 [06:24<10:14,  9.92s/it]

Epoch: 37, Train Loss: 37.562, Val Loss: 0.737, Val Acc: 0.732


Train loss: 38.288:  39%|███▉      | 39/100 [06:34<10:04,  9.91s/it]

Epoch: 38, Train Loss: 38.288, Val Loss: 0.736, Val Acc: 0.738


Train loss: 38.996:  40%|████      | 40/100 [06:44<09:54,  9.91s/it]

Epoch: 39, Train Loss: 38.996, Val Loss: 0.713, Val Acc: 0.755


Train loss: 39.701:  41%|████      | 41/100 [06:54<09:44,  9.91s/it]

Epoch: 40, Train Loss: 39.701, Val Loss: 0.701, Val Acc: 0.738


Train loss: 40.390:  42%|████▏     | 42/100 [07:04<09:34,  9.91s/it]

Epoch: 41, Train Loss: 40.390, Val Loss: 0.707, Val Acc: 0.746


Train loss: 41.053:  43%|████▎     | 43/100 [07:14<09:25,  9.92s/it]

Epoch: 42, Train Loss: 41.053, Val Loss: 0.669, Val Acc: 0.746


Train loss: 41.719:  44%|████▍     | 44/100 [07:24<09:15,  9.92s/it]

Epoch: 43, Train Loss: 41.719, Val Loss: 0.717, Val Acc: 0.721


Train loss: 42.366:  45%|████▌     | 45/100 [07:34<09:05,  9.92s/it]

Epoch: 44, Train Loss: 42.366, Val Loss: 0.649, Val Acc: 0.803


Train loss: 43.018:  46%|████▌     | 46/100 [07:44<08:55,  9.92s/it]

Epoch: 45, Train Loss: 43.018, Val Loss: 0.720, Val Acc: 0.755


Train loss: 43.665:  47%|████▋     | 47/100 [07:54<08:45,  9.92s/it]

Epoch: 46, Train Loss: 43.665, Val Loss: 0.674, Val Acc: 0.746


Train loss: 44.305:  48%|████▊     | 48/100 [08:03<08:35,  9.92s/it]

Epoch: 47, Train Loss: 44.305, Val Loss: 0.706, Val Acc: 0.758


Train loss: 44.921:  49%|████▉     | 49/100 [08:13<08:25,  9.92s/it]

Epoch: 48, Train Loss: 44.921, Val Loss: 0.687, Val Acc: 0.732


Train loss: 45.561:  50%|█████     | 50/100 [08:23<08:15,  9.91s/it]

Epoch: 49, Train Loss: 45.561, Val Loss: 0.666, Val Acc: 0.741


Train loss: 46.205:  51%|█████     | 51/100 [08:33<08:05,  9.91s/it]

Epoch: 50, Train Loss: 46.205, Val Loss: 0.717, Val Acc: 0.755


Train loss: 46.858:  52%|█████▏    | 52/100 [08:43<07:55,  9.91s/it]

Epoch: 51, Train Loss: 46.858, Val Loss: 0.666, Val Acc: 0.769


Train loss: 47.503:  53%|█████▎    | 53/100 [08:53<07:45,  9.91s/it]

Epoch: 52, Train Loss: 47.503, Val Loss: 0.629, Val Acc: 0.783


Train loss: 48.111:  54%|█████▍    | 54/100 [09:03<07:35,  9.91s/it]

Epoch: 53, Train Loss: 48.111, Val Loss: 0.652, Val Acc: 0.778


Train loss: 48.723:  55%|█████▌    | 55/100 [09:13<07:26,  9.92s/it]

Epoch: 54, Train Loss: 48.723, Val Loss: 0.627, Val Acc: 0.801


Train loss: 49.326:  56%|█████▌    | 56/100 [09:23<07:16,  9.92s/it]

Epoch: 55, Train Loss: 49.326, Val Loss: 0.655, Val Acc: 0.761


Train loss: 49.906:  57%|█████▋    | 57/100 [09:33<07:06,  9.91s/it]

Epoch: 56, Train Loss: 49.906, Val Loss: 0.621, Val Acc: 0.778


Train loss: 50.509:  58%|█████▊    | 58/100 [09:43<06:56,  9.91s/it]

Epoch: 57, Train Loss: 50.509, Val Loss: 0.634, Val Acc: 0.772


Train loss: 51.116:  59%|█████▉    | 59/100 [09:52<06:42,  9.82s/it]

Epoch: 58, Train Loss: 51.116, Val Loss: 0.634, Val Acc: 0.772


Train loss: 51.696:  60%|██████    | 60/100 [10:01<06:26,  9.65s/it]

Epoch: 59, Train Loss: 51.696, Val Loss: 0.610, Val Acc: 0.769


Train loss: 52.265:  61%|██████    | 61/100 [10:11<06:12,  9.54s/it]

Epoch: 60, Train Loss: 52.265, Val Loss: 0.627, Val Acc: 0.775


Train loss: 52.866:  62%|██████▏   | 62/100 [10:20<05:59,  9.47s/it]

Epoch: 61, Train Loss: 52.866, Val Loss: 0.671, Val Acc: 0.772


Train loss: 53.451:  63%|██████▎   | 63/100 [10:29<05:48,  9.41s/it]

Epoch: 62, Train Loss: 53.451, Val Loss: 0.608, Val Acc: 0.778


Train loss: 54.022:  64%|██████▍   | 64/100 [10:38<05:36,  9.35s/it]

Epoch: 63, Train Loss: 54.022, Val Loss: 0.580, Val Acc: 0.801


Train loss: 54.574:  65%|██████▌   | 65/100 [10:48<05:26,  9.34s/it]

Epoch: 64, Train Loss: 54.574, Val Loss: 0.577, Val Acc: 0.772


Train loss: 55.115:  66%|██████▌   | 66/100 [10:57<05:17,  9.33s/it]

Epoch: 65, Train Loss: 55.115, Val Loss: 0.572, Val Acc: 0.783


Train loss: 55.646:  67%|██████▋   | 67/100 [11:06<05:07,  9.31s/it]

Epoch: 66, Train Loss: 55.646, Val Loss: 0.588, Val Acc: 0.775


Train loss: 56.186:  68%|██████▊   | 68/100 [11:16<04:57,  9.29s/it]

Epoch: 67, Train Loss: 56.186, Val Loss: 0.593, Val Acc: 0.781


Train loss: 56.701:  69%|██████▉   | 69/100 [11:25<04:47,  9.28s/it]

Epoch: 68, Train Loss: 56.701, Val Loss: 0.541, Val Acc: 0.803


Train loss: 57.226:  70%|███████   | 70/100 [11:34<04:37,  9.26s/it]

Epoch: 69, Train Loss: 57.226, Val Loss: 0.546, Val Acc: 0.809


Train loss: 57.751:  71%|███████   | 71/100 [11:43<04:28,  9.25s/it]

Epoch: 70, Train Loss: 57.751, Val Loss: 0.579, Val Acc: 0.786


Train loss: 58.293:  72%|███████▏  | 72/100 [11:53<04:24,  9.43s/it]

Epoch: 71, Train Loss: 58.293, Val Loss: 0.573, Val Acc: 0.801


Train loss: 58.820:  73%|███████▎  | 73/100 [12:03<04:18,  9.58s/it]

Epoch: 72, Train Loss: 58.820, Val Loss: 0.577, Val Acc: 0.798


Train loss: 59.335:  74%|███████▍  | 74/100 [12:13<04:11,  9.67s/it]

Epoch: 73, Train Loss: 59.335, Val Loss: 0.618, Val Acc: 0.769


Train loss: 59.846:  75%|███████▌  | 75/100 [12:23<04:03,  9.74s/it]

Epoch: 74, Train Loss: 59.846, Val Loss: 0.546, Val Acc: 0.821


Train loss: 60.353:  76%|███████▌  | 76/100 [12:33<03:54,  9.78s/it]

Epoch: 75, Train Loss: 60.353, Val Loss: 0.547, Val Acc: 0.795


Train loss: 60.869:  77%|███████▋  | 77/100 [12:43<03:45,  9.81s/it]

Epoch: 76, Train Loss: 60.869, Val Loss: 0.515, Val Acc: 0.818


Train loss: 61.361:  78%|███████▊  | 78/100 [12:53<03:36,  9.85s/it]

Epoch: 77, Train Loss: 61.361, Val Loss: 0.509, Val Acc: 0.818


Train loss: 61.850:  79%|███████▉  | 79/100 [13:02<03:27,  9.86s/it]

Epoch: 78, Train Loss: 61.850, Val Loss: 0.584, Val Acc: 0.766


Train loss: 62.344:  80%|████████  | 80/100 [13:12<03:17,  9.86s/it]

Epoch: 79, Train Loss: 62.344, Val Loss: 0.527, Val Acc: 0.803


Train loss: 62.824:  81%|████████  | 81/100 [13:22<03:07,  9.87s/it]

Epoch: 80, Train Loss: 62.824, Val Loss: 0.553, Val Acc: 0.803


Train loss: 63.309:  82%|████████▏ | 82/100 [13:32<02:57,  9.87s/it]

Epoch: 81, Train Loss: 63.309, Val Loss: 0.523, Val Acc: 0.809


Train loss: 63.790:  83%|████████▎ | 83/100 [13:42<02:47,  9.87s/it]

Epoch: 82, Train Loss: 63.790, Val Loss: 0.534, Val Acc: 0.812


Train loss: 64.257:  84%|████████▍ | 84/100 [13:52<02:38,  9.88s/it]

Epoch: 83, Train Loss: 64.257, Val Loss: 0.520, Val Acc: 0.809


Train loss: 64.727:  85%|████████▌ | 85/100 [14:02<02:28,  9.87s/it]

Epoch: 84, Train Loss: 64.727, Val Loss: 0.520, Val Acc: 0.798


Train loss: 65.207:  86%|████████▌ | 86/100 [14:12<02:17,  9.85s/it]

Epoch: 85, Train Loss: 65.207, Val Loss: 0.501, Val Acc: 0.812


Train loss: 65.677:  87%|████████▋ | 87/100 [14:21<02:08,  9.85s/it]

Epoch: 86, Train Loss: 65.677, Val Loss: 0.569, Val Acc: 0.786


Train loss: 66.141:  88%|████████▊ | 88/100 [14:31<01:57,  9.82s/it]

Epoch: 87, Train Loss: 66.141, Val Loss: 0.501, Val Acc: 0.821


Train loss: 66.597:  89%|████████▉ | 89/100 [14:41<01:47,  9.80s/it]

Epoch: 88, Train Loss: 66.597, Val Loss: 0.507, Val Acc: 0.801


Train loss: 67.074:  90%|█████████ | 90/100 [14:51<01:37,  9.77s/it]

Epoch: 89, Train Loss: 67.074, Val Loss: 0.556, Val Acc: 0.775


Train loss: 67.516:  91%|█████████ | 91/100 [15:00<01:27,  9.74s/it]

Epoch: 90, Train Loss: 67.516, Val Loss: 0.525, Val Acc: 0.806


Train loss: 67.978:  92%|█████████▏| 92/100 [15:10<01:18,  9.75s/it]

Epoch: 91, Train Loss: 67.978, Val Loss: 0.499, Val Acc: 0.798


Train loss: 68.416:  93%|█████████▎| 93/100 [15:20<01:08,  9.78s/it]

Epoch: 92, Train Loss: 68.416, Val Loss: 0.512, Val Acc: 0.823


Train loss: 68.847:  94%|█████████▍| 94/100 [15:30<00:58,  9.78s/it]

Epoch: 93, Train Loss: 68.847, Val Loss: 0.504, Val Acc: 0.803


Train loss: 69.299:  95%|█████████▌| 95/100 [15:40<00:49,  9.81s/it]

Epoch: 94, Train Loss: 69.299, Val Loss: 0.495, Val Acc: 0.821


Train loss: 69.747:  96%|█████████▌| 96/100 [15:49<00:39,  9.83s/it]

Epoch: 95, Train Loss: 69.747, Val Loss: 0.566, Val Acc: 0.792


Train loss: 70.201:  97%|█████████▋| 97/100 [15:59<00:29,  9.73s/it]

Epoch: 96, Train Loss: 70.201, Val Loss: 0.535, Val Acc: 0.812


Train loss: 70.651:  98%|█████████▊| 98/100 [16:08<00:19,  9.64s/it]

Epoch: 97, Train Loss: 70.651, Val Loss: 0.529, Val Acc: 0.801


Train loss: 71.090:  99%|█████████▉| 99/100 [16:18<00:09,  9.58s/it]

Epoch: 98, Train Loss: 71.090, Val Loss: 0.502, Val Acc: 0.812


Train loss: 71.538: 100%|██████████| 100/100 [16:27<00:00,  9.88s/it]


Epoch: 99, Train Loss: 71.538, Val Loss: 0.547, Val Acc: 0.792
