In [1]:
import os
import pandas as pd
import numpy as np
from scipy import stats
from pathlib import Path
from dotenv import load_dotenv
from tqdm import tqdm
import wandb

import torch
from torch import nn
from torch.utils.data import DataLoader

from src.metrics import pearson_metric
from src.torch_models import MLP
from src.data import Dataset


# Read a local .env file into the process environment; the `dataset_dir`
# variable looked up via os.environ when loading the data is expected to come from it.
load_dotenv()

True

# Data preparation

In [2]:
# Resolve the dataset folder from the environment, then read the training table from it.
dataset_dir = Path(os.environ['dataset_dir'])
train_csv_path = dataset_dir.joinpath('train.csv')
data = pd.read_csv(train_csv_path)

In [3]:
# Index the rows by `row_id`.
# Guarded so that re-running this cell is a no-op: after the first run `row_id`
# is the index rather than a column, and an unguarded set_index would raise KeyError.
if 'row_id' in data.columns:
    data = data.set_index('row_id')

In [4]:
# Chronological hold-out: train on earlier time_ids, test on later ones.
# The boundary is named once so the two complementary filters cannot drift apart,
# and both splits use the same boolean-mask idiom.
split_time_id = 1000
train = data[data.time_id < split_time_id]
test = data[data.time_id >= split_time_id]

In [5]:
# Wrap each split in the project's Dataset class (src.data) for use with a DataLoader.
train_dataset, test_dataset = Dataset(train), Dataset(test)

# Hyperparameters

In [7]:
# batch_size: samples per mini-batch handed to the DataLoaders.
# epochs: number of full passes over the training split.
batch_size, epochs = 10_000, 1

# Configure Training

In [8]:
# Batch both splits with the same batch_size.
# NOTE(review): neither loader passes shuffle, so batches follow the dataset's own
# order (DataLoader default) — confirm this is intended for the training split.
train_dataloader, test_dataloader = (
    DataLoader(split_dataset, batch_size=batch_size)
    for split_dataset in (train_dataset, test_dataset)
)

In [9]:
# Seed PyTorch's global RNG before the model is built so that any torch-based
# random weight initialisation is repeatable across Restart & Run All.
torch.manual_seed(42)

# NOTE(review): input_dim=302 presumably equals the number of feature values per
# sample produced by Dataset — confirm against src.data.
model = MLP(input_dim=302)
# AdamW over all model parameters with a fixed learning rate of 1e-3.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# Mean-squared-error objective for the regression target.
loss_function = nn.MSELoss()

In [10]:
# Start the Weights & Biases run that the training loop reports its losses to.
wandb.init(
    name='mlp_test3',
    project="market_prediction",
    entity="parmezano",
)

[34m[1mwandb[0m: Currently logged in as: [33mimplausible_denyability[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.9 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


# Training

In [11]:
def _epoch_mean(batch_losses, batch_sizes):
    """Return the sample-weighted mean of per-batch mean losses.

    Each entry of ``batch_losses`` is already a mean over its batch, so weighting
    by ``batch_sizes`` recovers the exact mean over all samples even when the
    final batch is shorter than the rest. Returns NaN if no batch was seen.
    """
    if not batch_losses:
        return float("nan")
    return float(np.average(batch_losses, weights=batch_sizes))


# One optimisation pass over the training split per epoch, followed by a
# gradient-free evaluation pass over the test split; both epoch losses are logged.
for epoch in range(epochs):
    model.train()
    train_losses, train_sizes = [], []
    for x, y_true in tqdm(train_dataloader, desc=f"epoch {epoch} train"):
        optimizer.zero_grad()
        y_pred = model(x)
        # nn.MSELoss takes (input, target): prediction first, ground truth second.
        loss = loss_function(y_pred.view(-1), y_true)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        train_sizes.append(len(y_true))

    model.eval()
    test_losses, test_sizes = [], []
    with torch.no_grad():  # evaluation only: no autograd graph is needed
        for x, y_true in tqdm(test_dataloader, desc=f"epoch {epoch} test"):
            y_pred = model(x)
            loss = loss_function(y_pred.view(-1), y_true)
            test_losses.append(loss.item())
            test_sizes.append(len(y_true))

    wandb.log({
        "epoch": epoch,
        "train_loss": _epoch_mean(train_losses, train_sizes),
        "test_loss": _epoch_mean(test_losses, test_sizes),
    })

100%|██████████| 244/244 [01:17<00:00,  3.15it/s]
100%|██████████| 71/71 [00:11<00:00,  6.04it/s]
