In [1]:
import pathlib

import torch
import torch.nn.functional as F
import torch.optim as optim
from dataset import SequenceDataset
from sequence_transformations import TransformationRefined
from torch import Tensor, nn
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from training_loop import fit

In [2]:
# Report the CUDA environment.  The device queries below are only valid when
# CUDA is actually usable, so they are guarded: on a CPU-only machine the cell
# prints the availability flag and moves on instead of raising, which keeps the
# notebook runnable with the CPU fallback chosen in the next cell.
cuda_available = torch.cuda.is_available()
print("GPU available:", cuda_available)
if cuda_available:
    device_id = torch.cuda.current_device()
    print("Device id:", device_id)
    print("GPU:", torch.cuda.get_device_name(device_id))

GPU available: True
Device id: 0
GPU: Quadro T1000 with Max-Q Design


In [3]:
# Prefer the GPU when one is usable; otherwise run everything on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Notebook-wide configuration.
# CSV file handed to SequenceDataset (path is relative to the notebook's working directory).
DATASET_PATH = pathlib.Path("../data/classification/data.csv")
# Training batch size; the validation loader uses twice this value.
BATCH_SIZE = 128
# Sequence length passed to the model constructor to size its input layer.
SEQUENCE_LEN = 500

# "Refined" Representation

In [5]:
# Fraction of the samples used for training; the remainder is held out for validation.
TRAIN_FRACTION = 0.8
# Fixed seed so the train/validation partition is identical on every
# Restart & Run All; without it each run validates on a different subset and
# the logged curves are not comparable between runs.
SPLIT_SEED = 42

dataset = SequenceDataset(DATASET_PATH, TransformationRefined())

train_size = int(TRAIN_FRACTION * len(dataset))
# Derive the validation size by subtraction so the two sizes always sum to len(dataset).
val_size = len(dataset) - train_size

train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training data is reshuffled each epoch.  The validation loader uses
# a batch twice as large as the training one.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=(BATCH_SIZE * 2), shuffle=False)

## Baseline

In [11]:
class LogisticRegression(nn.Module):
    """Logistic-regression baseline over a flattened sequence.

    Each sample of shape ``(T, C)`` is flattened to a ``T * C`` vector and fed
    through a single linear layer followed by a sigmoid, yielding one
    probability per sample.

    Args:
        sequence_len: Number of time steps ``T`` in every input sequence.
        n_channels: Number of features ``C`` per time step.  Defaults to 2,
            the value that was previously hard-coded.
    """

    def __init__(self, sequence_len: int, n_channels: int = 2) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features=sequence_len * n_channels, out_features=1)

    def forward(self, x: Tensor) -> Tensor:  # B, T, C
        """Return the predicted probability for each sample.

        Args:
            x: Input batch of shape ``(B, T, C)``.

        Returns:
            Tensor of shape ``(B,)`` with values in ``(0, 1)``.
        """
        # flatten() also accepts non-contiguous inputs, where view() would raise.
        x = x.flatten(start_dim=1)  # B, T * C
        x = torch.sigmoid(self.linear(x))  # B, 1
        # Drop only the trailing feature axis.  A bare squeeze() would also
        # remove the batch axis when B == 1 and return a 0-dim tensor that no
        # longer matches a target of shape (1,).
        x = x.squeeze(-1)  # B
        return x

In [12]:
# Build the baseline and move its parameters to the selected device
# (Module.to returns the module itself, so rebinding the name is safe).
logistic_regression_model = LogisticRegression(sequence_len=SEQUENCE_LEN)
logistic_regression_model = logistic_regression_model.to(device)

# Adam over all of the model's parameters.
opt = optim.Adam(params=logistic_regression_model.parameters(), lr=0.005)

In [13]:
# TensorBoard event files for this run are written under the directory below.
writer = SummaryWriter(log_dir="runs/logistic_regression_refined")

In [14]:
# Train the baseline.  binary_cross_entropy takes probabilities, which is what
# the model's sigmoid output provides.
fit(
    model=logistic_regression_model,
    opt=opt,
    loss_func=F.binary_cross_entropy,
    epochs=30,
    train_dl=train_loader,
    valid_dl=val_loader,
    device=device,
    writer=writer,
)

100%|██████████| 30/30 [02:39<00:00,  5.32s/it]


In [15]:
# Force any buffered TensorBoard events to disk so the run is visible immediately.
writer.flush()