<a href="https://colab.research.google.com/github/czarodziejszyn/ssne/blob/main/projekt5/recursive_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [41]:
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [57]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
import csv
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from collections import Counter

In [61]:
pickle_path = "/content/drive/MyDrive/data/train.pkl"

with open(pickle_path, 'rb') as f:
    data = pickle.load(f)

test_path = "/content/drive/MyDrive/data/test_no_target.pkl"

with open(test_path, 'rb') as f:
    test_data = pickle.load(f)

In [62]:
print(f"data length: {len(data)}")
print(f"element type: {type(data[0])}")
print(f"song fragment: {data[0][0][:10]}")
print(f"song class: {data[0][1]}")

data length: 2939
element type: <class 'tuple'>
song fragment: [ -1.  -1.  -1.  -1. 144. 144. 144.  64.  67.   0.]
song class: 0


In [63]:
X, y = zip(*data)

## Rozkład klas

In [64]:
class_counts = Counter(y)
print("class counts:")
for key, value in class_counts.items():
    print(f"{key}: {value}")

class counts:
0: 1630
1: 478
2: 154
3: 441
4: 236


## Statystyki utworów

In [65]:
lengths = [len(song) for song in X]
print(f"average length: {np.mean(lengths)}")
print(f"max length: {np.max(lengths)}")
print(f"min length: {np.min(lengths)}")

average length: 436.50493365090165
max length: 6308
min length: 4


## Preprocessing

In [66]:
train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)

In [67]:
def prepare_data(data, test=False):
    if test:
        sequences = [torch.tensor(seq, dtype=torch.float32) for seq in data]
        labels = None
    else:
        sequences, labels = zip(*data)
        sequences = [torch.tensor(seq, dtype=torch.float32) for seq in sequences]
        labels = torch.tensor(labels, dtype=torch.long)
    lengths = torch.tensor([len(seq) for seq in sequences], dtype=torch.long)
    return sequences, labels, lengths

In [68]:
def collate_fn(batch):
    if isinstance(batch[0], tuple):
        sequences, labels, lengths = zip(*batch)
        padded = pad_sequence(sequences, batch_first=True)
        return padded.unsqueeze(-1), torch.stack(labels), torch.stack(lengths)
    else:
        sequences, lengths = zip(*batch)
        padded = pad_sequence(sequences, batch_first=True)
        return padded.unsqueeze(-1), torch.stack(lengths)

## Dataloadery

In [69]:
def create_dataloader(data, batch_size, test=False):
    if test:
        sequences, _, lengths = prepare_data(data, test=True)
        dataset = list(zip(sequences, lengths))
    else:
        sequences, labels, lengths = prepare_data(data, test=False)
        dataset = list(zip(sequences, labels, lengths))
    return DataLoader(dataset, batch_size=batch_size, shuffle=not test, collate_fn=collate_fn)

In [70]:
BATCH_SIZE = 64
train_loader = create_dataloader(train_data, BATCH_SIZE)
val_loader = create_dataloader(val_data, BATCH_SIZE)
test_loader = create_dataloader(test_data, BATCH_SIZE, test=True)

## Model

In [77]:
class LSTMClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, out_size, bidirectional=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dir_mult = 2 if bidirectional else 1

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=bidirectional,
            batch_first=True
        )
        self.fc = nn.Linear(hidden_size * self.dir_mult, out_size)

    def forward(self, x, lengths):
        h0 = torch.zeros(self.num_layers * self.dir_mult, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers * self.dir_mult, x.size(0), self.hidden_size).to(device)
        x_packed = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_output, (hn, cn) = self.lstm(x_packed, (h0, c0))
        if self.bidirectional:
            last_hidden = torch.cat((hn[-2], hn[-1]), dim=1)
        else:
            last_hidden = hn[-1]
        out = self.fc(last_hidden)
        return out

## Trening

In [78]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [79]:
def evaluate(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y, lengths in loader:
            x, y, lengths = x.to(device), y.to(device), lengths.to(device)
            outputs = model(x, lengths)
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total

In [80]:
def train(model, train_loader, val_loader, epochs=10):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0
        for x, y, lengths in tqdm(train_loader, desc=f"Epoch {epoch}"):
            x, y, lengths = x.to(device), y.to(device), lengths.to(device)
            optimizer.zero_grad()
            outputs = model(x, lengths)
            loss = loss_fn(outputs, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        val_acc = evaluate(model, val_loader)
        print(f"Epoch {epoch}, Loss: {total_loss:.4f}, Val Accuracy: {val_acc:.4f}")
    torch.save(model.state_dict(), "lstm_classifier.pt")

In [81]:
def generate_predictions(model, test_loader, output_file="predictions.csv"):
    model.eval()
    predictions = []
    with torch.no_grad():
        for x, lengths in test_loader:
            x, lengths = x.to(device), lengths.to(device)
            outputs = model(x, lengths)
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            predictions.extend(preds)
    with open(output_file, "w", newline="") as f:
        writer = csv.writer(f)
        for pred in predictions:
            writer.writerow([pred])
    print(f"Predictions saved to {output_file}")

In [None]:
model = LSTMClassifier(input_size=1, hidden_size=128, num_layers=2, out_size=5, bidirectional=True).to(device)
train(model, train_loader, val_loader, epochs=100)
generate_predictions(model, test_loader)

Epoch 1: 100%|██████████| 37/37 [00:10<00:00,  3.44it/s]


Epoch 1, Loss: 45.5369, Val Accuracy: 0.5680


Epoch 2: 100%|██████████| 37/37 [00:11<00:00,  3.33it/s]


Epoch 2, Loss: 40.3167, Val Accuracy: 0.6003


Epoch 3: 100%|██████████| 37/37 [00:11<00:00,  3.35it/s]


Epoch 3, Loss: 36.6782, Val Accuracy: 0.6565


Epoch 4: 100%|██████████| 37/37 [00:11<00:00,  3.35it/s]


Epoch 4, Loss: 34.6711, Val Accuracy: 0.6752


Epoch 5: 100%|██████████| 37/37 [00:11<00:00,  3.32it/s]


Epoch 5, Loss: 33.0453, Val Accuracy: 0.6599


Epoch 6: 100%|██████████| 37/37 [00:11<00:00,  3.29it/s]


Epoch 6, Loss: 31.0884, Val Accuracy: 0.7007


Epoch 7: 100%|██████████| 37/37 [00:11<00:00,  3.29it/s]


Epoch 7, Loss: 30.3786, Val Accuracy: 0.7007


Epoch 8: 100%|██████████| 37/37 [00:10<00:00,  3.53it/s]


Epoch 8, Loss: 29.2521, Val Accuracy: 0.7126


Epoch 9: 100%|██████████| 37/37 [00:10<00:00,  3.49it/s]


Epoch 9, Loss: 27.3089, Val Accuracy: 0.6735


Epoch 10: 100%|██████████| 37/37 [00:10<00:00,  3.57it/s]


Epoch 10, Loss: 27.4496, Val Accuracy: 0.7364


Epoch 11: 100%|██████████| 37/37 [00:11<00:00,  3.32it/s]


Epoch 11, Loss: 26.0200, Val Accuracy: 0.7024


Epoch 12: 100%|██████████| 37/37 [00:11<00:00,  3.29it/s]


Epoch 12, Loss: 26.0793, Val Accuracy: 0.6990


Epoch 13: 100%|██████████| 37/37 [00:11<00:00,  3.29it/s]


Epoch 13, Loss: 26.4850, Val Accuracy: 0.6922


Epoch 14: 100%|██████████| 37/37 [00:11<00:00,  3.25it/s]


Epoch 14, Loss: 23.7857, Val Accuracy: 0.7415


Epoch 15: 100%|██████████| 37/37 [00:10<00:00,  3.40it/s]


Epoch 15, Loss: 22.5830, Val Accuracy: 0.7381


Epoch 16: 100%|██████████| 37/37 [00:11<00:00,  3.35it/s]


Epoch 16, Loss: 21.7266, Val Accuracy: 0.7483


Epoch 17: 100%|██████████| 37/37 [00:11<00:00,  3.09it/s]


Epoch 17, Loss: 22.3726, Val Accuracy: 0.7296


Epoch 18: 100%|██████████| 37/37 [00:12<00:00,  3.06it/s]


Epoch 18, Loss: 20.0523, Val Accuracy: 0.7636


Epoch 19: 100%|██████████| 37/37 [00:11<00:00,  3.18it/s]


Epoch 19, Loss: 18.7747, Val Accuracy: 0.7551


Epoch 20: 100%|██████████| 37/37 [00:11<00:00,  3.28it/s]


Epoch 20, Loss: 18.5170, Val Accuracy: 0.7364


Epoch 21: 100%|██████████| 37/37 [00:11<00:00,  3.28it/s]


Epoch 21, Loss: 18.3670, Val Accuracy: 0.7398


Epoch 22: 100%|██████████| 37/37 [00:11<00:00,  3.30it/s]


Epoch 22, Loss: 17.0652, Val Accuracy: 0.7500


Epoch 23: 100%|██████████| 37/37 [00:11<00:00,  3.33it/s]


Epoch 23, Loss: 15.7180, Val Accuracy: 0.7534


Epoch 24: 100%|██████████| 37/37 [00:11<00:00,  3.23it/s]


Epoch 24, Loss: 14.2454, Val Accuracy: 0.7568


Epoch 25: 100%|██████████| 37/37 [00:10<00:00,  3.41it/s]


Epoch 25, Loss: 13.0679, Val Accuracy: 0.7534


Epoch 26: 100%|██████████| 37/37 [00:11<00:00,  3.32it/s]


Epoch 26, Loss: 12.7049, Val Accuracy: 0.7313


Epoch 27: 100%|██████████| 37/37 [00:11<00:00,  3.18it/s]


Epoch 27, Loss: 11.8936, Val Accuracy: 0.7602


Epoch 28: 100%|██████████| 37/37 [00:11<00:00,  3.23it/s]


Epoch 28, Loss: 10.4226, Val Accuracy: 0.7636


Epoch 29: 100%|██████████| 37/37 [00:11<00:00,  3.25it/s]


Epoch 29, Loss: 9.0765, Val Accuracy: 0.7585


Epoch 30: 100%|██████████| 37/37 [00:11<00:00,  3.28it/s]


Epoch 30, Loss: 9.1369, Val Accuracy: 0.7738


Epoch 31: 100%|██████████| 37/37 [00:11<00:00,  3.25it/s]


Epoch 31, Loss: 8.2945, Val Accuracy: 0.7619


Epoch 32: 100%|██████████| 37/37 [00:11<00:00,  3.30it/s]


Epoch 32, Loss: 7.7096, Val Accuracy: 0.7585


Epoch 33: 100%|██████████| 37/37 [00:11<00:00,  3.26it/s]
