In [1]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
import csv
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from collections import Counter

In [3]:
pickle_path = "./train.pkl"

with open(pickle_path, 'rb') as f:
    data = pickle.load(f)

test_path = "./test_no_target.pkl"

with open(test_path, 'rb') as f:
    test_data = pickle.load(f)

In [4]:
print(f"data length: {len(data)}")
print(f"element type: {type(data[0])}")
print(f"song fragment: {data[0][0][:10]}")
print(f"song class: {data[0][1]}")

data length: 2939
element type: <class 'tuple'>
song fragment: [ -1.  -1.  -1.  -1. 144. 144. 144.  64.  67.   0.]
song class: 0


In [5]:
X, y = zip(*data)

## Rozkład klas

In [6]:
class_counts = Counter(y)
print("class counts:")
for key, value in class_counts.items():
    print(f"{key}: {value}")

class counts:
0: 1630
1: 478
2: 154
3: 441
4: 236


## Statystyki utworów

In [7]:
lengths = [len(song) for song in X]
print(f"average length: {np.mean(lengths)}")
print(f"max length: {np.max(lengths)}")
print(f"min length: {np.min(lengths)}")

average length: 436.50493365090165
max length: 6308
min length: 4


## Preprocessing

In [8]:
train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)

In [9]:
def prepare_data(data, test=False):
    if test:
        sequences = [torch.tensor(seq, dtype=torch.float32) for seq in data]
        labels = None
    else:
        sequences, labels = zip(*data)
        sequences = [torch.tensor(seq, dtype=torch.float32) for seq in sequences]
        labels = torch.tensor(labels, dtype=torch.long)
    lengths = torch.tensor([len(seq) for seq in sequences], dtype=torch.long)
    return sequences, labels, lengths

In [21]:
def collate_fn(batch):
    if len(batch[0]) == 3:
        sequences, labels, lengths = zip(*batch)
        padded = pad_sequence(sequences, batch_first=True)
        return padded.unsqueeze(-1), torch.stack(labels), torch.stack(lengths)
    else:
        sequences, lengths = zip(*batch)
        padded = pad_sequence(sequences, batch_first=True)
        return padded.unsqueeze(-1), torch.stack(lengths)

## Dataloadery

In [23]:
def create_dataloader(data, batch_size, test=False):
    if test:
        sequences, _, lengths = prepare_data(data, test=True)
        dataset = list(zip(sequences, lengths))
    else:
        sequences, labels, lengths = prepare_data(data, test=False)
        dataset = list(zip(sequences, labels, lengths))
    return DataLoader(dataset, batch_size=batch_size, shuffle=not test, collate_fn=collate_fn)

In [24]:
BATCH_SIZE = 64
train_loader = create_dataloader(train_data, BATCH_SIZE)
val_loader = create_dataloader(val_data, BATCH_SIZE)
test_loader = create_dataloader(test_data, BATCH_SIZE, test=True)

## Model

In [13]:
class LSTMClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, out_size, bidirectional=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dir_mult = 2 if bidirectional else 1

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=bidirectional,
            batch_first=True
        )
        self.fc = nn.Linear(hidden_size * self.dir_mult, out_size)

    def forward(self, x, lengths):
        h0 = torch.zeros(self.num_layers * self.dir_mult, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers * self.dir_mult, x.size(0), self.hidden_size).to(device)
        x_packed = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_output, (hn, cn) = self.lstm(x_packed, (h0, c0))
        if self.bidirectional:
            last_hidden = torch.cat((hn[-2], hn[-1]), dim=1)
        else:
            last_hidden = hn[-1]
        out = self.fc(last_hidden)
        return out

## Trening

In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [15]:
def evaluate(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y, lengths in loader:
            x, y, lengths = x.to(device), y.to(device), lengths.to(device)
            outputs = model(x, lengths)
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total

In [16]:
def train(model, train_loader, val_loader, epochs=10):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0
        for x, y, lengths in tqdm(train_loader, desc=f"Epoch {epoch}"):
            x, y, lengths = x.to(device), y.to(device), lengths.to(device)
            optimizer.zero_grad()
            outputs = model(x, lengths)
            loss = loss_fn(outputs, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        val_acc = evaluate(model, val_loader)
        print(f"Epoch {epoch}, Loss: {total_loss:.4f}, Val Accuracy: {val_acc:.4f}")
    torch.save(model.state_dict(), "lstm_classifier.pt")

In [26]:
def generate_predictions(model, test_loader, output_file="pred.csv"):
    model.eval()
    predictions = []
    with torch.no_grad():
        for x, lengths in test_loader:
            x, lengths = x.to(device), lengths.to(device)
            outputs = model(x, lengths)
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            predictions.extend(preds)
    with open(output_file, "w", newline="") as f:
        writer = csv.writer(f)
        for pred in predictions:
            writer.writerow([pred])
    print(f"Predictions saved to {output_file}")

In [18]:
model = LSTMClassifier(input_size=1, hidden_size=128, num_layers=2, out_size=5, bidirectional=True).to(device)
train(model, train_loader, val_loader, epochs=100)

Epoch 1: 100%|█████████████████████████████████████████| 37/37 [00:08<00:00,  4.20it/s]


Epoch 1, Loss: 46.5515, Val Accuracy: 0.5697


Epoch 2: 100%|█████████████████████████████████████████| 37/37 [00:08<00:00,  4.35it/s]


Epoch 2, Loss: 41.1233, Val Accuracy: 0.6173


Epoch 3: 100%|█████████████████████████████████████████| 37/37 [00:08<00:00,  4.30it/s]


Epoch 3, Loss: 37.2542, Val Accuracy: 0.6531


Epoch 4: 100%|█████████████████████████████████████████| 37/37 [00:08<00:00,  4.32it/s]


Epoch 4, Loss: 34.9844, Val Accuracy: 0.6361


Epoch 5: 100%|█████████████████████████████████████████| 37/37 [00:08<00:00,  4.52it/s]


Epoch 5, Loss: 33.1875, Val Accuracy: 0.6650


Epoch 6: 100%|█████████████████████████████████████████| 37/37 [00:08<00:00,  4.56it/s]


Epoch 6, Loss: 31.8000, Val Accuracy: 0.6888


Epoch 7: 100%|█████████████████████████████████████████| 37/37 [00:08<00:00,  4.46it/s]


Epoch 7, Loss: 30.7331, Val Accuracy: 0.6956


Epoch 8: 100%|█████████████████████████████████████████| 37/37 [00:08<00:00,  4.33it/s]


Epoch 8, Loss: 30.0626, Val Accuracy: 0.6854


Epoch 9: 100%|█████████████████████████████████████████| 37/37 [00:08<00:00,  4.52it/s]


Epoch 9, Loss: 29.4991, Val Accuracy: 0.7194


Epoch 10: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.45it/s]


Epoch 10, Loss: 27.4755, Val Accuracy: 0.7160


Epoch 11: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.51it/s]


Epoch 11, Loss: 26.1697, Val Accuracy: 0.7211


Epoch 12: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.49it/s]


Epoch 12, Loss: 25.6220, Val Accuracy: 0.7092


Epoch 13: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.51it/s]


Epoch 13, Loss: 24.0529, Val Accuracy: 0.7347


Epoch 14: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.33it/s]


Epoch 14, Loss: 23.6316, Val Accuracy: 0.7041


Epoch 15: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.41it/s]


Epoch 15, Loss: 22.9616, Val Accuracy: 0.7313


Epoch 16: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.40it/s]


Epoch 16, Loss: 21.1459, Val Accuracy: 0.7092


Epoch 17: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.30it/s]


Epoch 17, Loss: 22.0850, Val Accuracy: 0.7466


Epoch 18: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.42it/s]


Epoch 18, Loss: 20.2708, Val Accuracy: 0.7585


Epoch 19: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.36it/s]


Epoch 19, Loss: 18.3910, Val Accuracy: 0.7483


Epoch 20: 100%|████████████████████████████████████████| 37/37 [00:07<00:00,  4.68it/s]


Epoch 20, Loss: 17.4404, Val Accuracy: 0.7381


Epoch 21: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.51it/s]


Epoch 21, Loss: 17.1130, Val Accuracy: 0.7619


Epoch 22: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.54it/s]


Epoch 22, Loss: 15.3098, Val Accuracy: 0.7500


Epoch 23: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.37it/s]


Epoch 23, Loss: 15.3635, Val Accuracy: 0.7585


Epoch 24: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.27it/s]


Epoch 24, Loss: 13.4018, Val Accuracy: 0.7636


Epoch 25: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.30it/s]


Epoch 25, Loss: 11.8592, Val Accuracy: 0.7619


Epoch 26: 100%|████████████████████████████████████████| 37/37 [00:07<00:00,  4.64it/s]


Epoch 26, Loss: 10.8520, Val Accuracy: 0.7534


Epoch 27: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.36it/s]


Epoch 27, Loss: 10.1338, Val Accuracy: 0.7789


Epoch 28: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.38it/s]


Epoch 28, Loss: 9.1890, Val Accuracy: 0.7415


Epoch 29: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.34it/s]


Epoch 29, Loss: 9.6832, Val Accuracy: 0.7636


Epoch 30: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.49it/s]


Epoch 30, Loss: 8.1904, Val Accuracy: 0.7415


Epoch 31: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.27it/s]


Epoch 31, Loss: 7.2305, Val Accuracy: 0.7568


Epoch 32: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.50it/s]


Epoch 32, Loss: 5.7106, Val Accuracy: 0.7772


Epoch 33: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.33it/s]


Epoch 33, Loss: 5.3585, Val Accuracy: 0.7466


Epoch 34: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.39it/s]


Epoch 34, Loss: 4.8318, Val Accuracy: 0.7534


Epoch 35: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.37it/s]


Epoch 35, Loss: 4.3365, Val Accuracy: 0.7619


Epoch 36: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.33it/s]


Epoch 36, Loss: 3.7124, Val Accuracy: 0.7755


Epoch 37: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.32it/s]


Epoch 37, Loss: 3.5943, Val Accuracy: 0.7551


Epoch 38: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.23it/s]


Epoch 38, Loss: 2.6126, Val Accuracy: 0.7636


Epoch 39: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.31it/s]


Epoch 39, Loss: 4.7813, Val Accuracy: 0.7381


Epoch 40: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.27it/s]


Epoch 40, Loss: 7.5418, Val Accuracy: 0.7636


Epoch 41: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.32it/s]


Epoch 41, Loss: 4.9812, Val Accuracy: 0.7534


Epoch 42: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.30it/s]


Epoch 42, Loss: 3.6449, Val Accuracy: 0.7653


Epoch 43: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.53it/s]


Epoch 43, Loss: 2.1128, Val Accuracy: 0.7721


Epoch 44: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.36it/s]


Epoch 44, Loss: 1.2661, Val Accuracy: 0.7653


Epoch 45: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.35it/s]


Epoch 45, Loss: 0.9221, Val Accuracy: 0.7585


Epoch 46: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.48it/s]


Epoch 46, Loss: 0.5877, Val Accuracy: 0.7772


Epoch 47: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.58it/s]


Epoch 47, Loss: 0.4601, Val Accuracy: 0.7721


Epoch 48: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.28it/s]


Epoch 48, Loss: 0.3526, Val Accuracy: 0.7687


Epoch 49: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.48it/s]


Epoch 49, Loss: 0.2845, Val Accuracy: 0.7653


Epoch 50: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.37it/s]


Epoch 50, Loss: 0.2343, Val Accuracy: 0.7670


Epoch 51: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.33it/s]


Epoch 51, Loss: 0.2056, Val Accuracy: 0.7721


Epoch 52: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.47it/s]


Epoch 52, Loss: 0.1820, Val Accuracy: 0.7721


Epoch 53: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.33it/s]


Epoch 53, Loss: 0.1641, Val Accuracy: 0.7687


Epoch 54: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.40it/s]


Epoch 54, Loss: 0.1473, Val Accuracy: 0.7687


Epoch 55: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.36it/s]


Epoch 55, Loss: 0.1452, Val Accuracy: 0.7738


Epoch 56: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.43it/s]


Epoch 56, Loss: 0.1454, Val Accuracy: 0.7653


Epoch 57: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.52it/s]


Epoch 57, Loss: 0.1308, Val Accuracy: 0.7670


Epoch 58: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.39it/s]


Epoch 58, Loss: 0.1081, Val Accuracy: 0.7670


Epoch 59: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.45it/s]


Epoch 59, Loss: 0.0956, Val Accuracy: 0.7636


Epoch 60: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.48it/s]


Epoch 60, Loss: 0.0883, Val Accuracy: 0.7636


Epoch 61: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.41it/s]


Epoch 61, Loss: 0.0816, Val Accuracy: 0.7636


Epoch 62: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.48it/s]


Epoch 62, Loss: 0.0771, Val Accuracy: 0.7653


Epoch 63: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.25it/s]


Epoch 63, Loss: 0.0720, Val Accuracy: 0.7619


Epoch 64: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.48it/s]


Epoch 64, Loss: 0.0667, Val Accuracy: 0.7585


Epoch 65: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.28it/s]


Epoch 65, Loss: 0.0631, Val Accuracy: 0.7619


Epoch 66: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.44it/s]


Epoch 66, Loss: 0.0591, Val Accuracy: 0.7653


Epoch 67: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.44it/s]


Epoch 67, Loss: 0.0566, Val Accuracy: 0.7653


Epoch 68: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.18it/s]


Epoch 68, Loss: 0.0527, Val Accuracy: 0.7585


Epoch 69: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.49it/s]


Epoch 69, Loss: 0.0500, Val Accuracy: 0.7653


Epoch 70: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.44it/s]


Epoch 70, Loss: 0.0473, Val Accuracy: 0.7653


Epoch 71: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.38it/s]


Epoch 71, Loss: 0.0446, Val Accuracy: 0.7619


Epoch 72: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.47it/s]


Epoch 72, Loss: 0.0425, Val Accuracy: 0.7636


Epoch 73: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.52it/s]


Epoch 73, Loss: 0.0404, Val Accuracy: 0.7636


Epoch 74: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.39it/s]


Epoch 74, Loss: 0.0386, Val Accuracy: 0.7653


Epoch 75: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.36it/s]


Epoch 75, Loss: 0.0368, Val Accuracy: 0.7636


Epoch 76: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.28it/s]


Epoch 76, Loss: 0.0351, Val Accuracy: 0.7636


Epoch 77: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.38it/s]


Epoch 77, Loss: 0.0341, Val Accuracy: 0.7619


Epoch 78: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.44it/s]


Epoch 78, Loss: 0.0323, Val Accuracy: 0.7602


Epoch 79: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.40it/s]


Epoch 79, Loss: 0.0306, Val Accuracy: 0.7602


Epoch 80: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.33it/s]


Epoch 80, Loss: 0.0293, Val Accuracy: 0.7619


Epoch 81: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.33it/s]


Epoch 81, Loss: 0.0281, Val Accuracy: 0.7619


Epoch 82: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.38it/s]


Epoch 82, Loss: 0.0269, Val Accuracy: 0.7585


Epoch 83: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.45it/s]


Epoch 83, Loss: 0.0258, Val Accuracy: 0.7619


Epoch 84: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.30it/s]


Epoch 84, Loss: 0.0248, Val Accuracy: 0.7636


Epoch 85: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.29it/s]


Epoch 85, Loss: 0.0241, Val Accuracy: 0.7619


Epoch 86: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.39it/s]


Epoch 86, Loss: 0.0232, Val Accuracy: 0.7585


Epoch 87: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.44it/s]


Epoch 87, Loss: 0.0221, Val Accuracy: 0.7602


Epoch 88: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.40it/s]


Epoch 88, Loss: 0.0213, Val Accuracy: 0.7602


Epoch 89: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.49it/s]


Epoch 89, Loss: 0.0204, Val Accuracy: 0.7568


Epoch 90: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.41it/s]


Epoch 90, Loss: 0.0197, Val Accuracy: 0.7602


Epoch 91: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.43it/s]


Epoch 91, Loss: 0.0189, Val Accuracy: 0.7551


Epoch 92: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.46it/s]


Epoch 92, Loss: 0.0182, Val Accuracy: 0.7585


Epoch 93: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.42it/s]


Epoch 93, Loss: 0.0176, Val Accuracy: 0.7619


Epoch 94: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.41it/s]


Epoch 94, Loss: 0.0171, Val Accuracy: 0.7585


Epoch 95: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.27it/s]


Epoch 95, Loss: 0.0165, Val Accuracy: 0.7619


Epoch 96: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.29it/s]


Epoch 96, Loss: 0.0158, Val Accuracy: 0.7602


Epoch 97: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.44it/s]


Epoch 97, Loss: 0.0153, Val Accuracy: 0.7602


Epoch 98: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.40it/s]


Epoch 98, Loss: 0.0148, Val Accuracy: 0.7619


Epoch 99: 100%|████████████████████████████████████████| 37/37 [00:08<00:00,  4.46it/s]


Epoch 99, Loss: 0.0142, Val Accuracy: 0.7602


Epoch 100: 100%|███████████████████████████████████████| 37/37 [00:08<00:00,  4.27it/s]


Epoch 100, Loss: 0.0138, Val Accuracy: 0.7636


ValueError: not enough values to unpack (expected 3, got 2)

In [25]:
generate_predictions(model, test_loader)

Predictions saved to predictions.csv
