# In[ ]:
from pathlib import Path
import sys
sys.path.append('/Users/dean/Projects/mousetrap')

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

import mousetrap.mouse_dataset as mouse_dataset
import mousetrap.models as mtm

# In[ ]:
# load data: read the training file into a MouseMovementDataset (pad=True;
# presumably pads trajectories to a common length -- confirm in mouse_dataset),
# convert samples with ToTensor, and serve them one shuffled sample per batch.
data_path = 'data/dsjtzs_txfz_training.txt'

tf = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)
train_dataset = mouse_dataset.MouseMovementDataset(
    data_path,
    pad=True,
    transform=tf,
)
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=1,
    shuffle=True,
)

# create a simple convolutional model
class SimpleModel(torch.nn.Module):
    """Three unpadded 1-D convolutions followed by a two-layer fully connected head.

    Input is expected as ``(batch, 3, input_size)`` (Conv1d layout with 3 input
    channels). The output has shape ``(batch, 1)`` and is passed through a
    sigmoid, so values lie in (0, 1).

    Args:
        input_size: sequence length of the input. Needed to size ``fc1`` because
            the convolutions use no padding and therefore shrink the sequence.
    """
    # note: simple model seems to converge at around epoch 20 with BCE

    def __init__(self, input_size: int = 300):
        super().__init__()
        # Registration order matters for parameters()/state_dict ordering; keep it.
        self.conv1 = torch.nn.Conv1d(in_channels=3, out_channels=64, kernel_size=7)
        self.conv2 = torch.nn.Conv1d(in_channels=64, out_channels=64, kernel_size=5)
        self.conv3 = torch.nn.Conv1d(in_channels=64, out_channels=8, kernel_size=3)
        self.relu = torch.nn.ReLU()

        # A stride-1, unpadded Conv1d shortens the sequence by (kernel_size - 1).
        conv_stack = (self.conv1, self.conv2, self.conv3)
        conv_out_len = input_size - sum(conv.kernel_size[0] - 1 for conv in conv_stack)
        flat_features = conv_out_len * self.conv3.out_channels

        self.fc1 = torch.nn.Linear(in_features=flat_features, out_features=256)
        self.fc2 = torch.nn.Linear(in_features=256, out_features=1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x: torch.Tensor):
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(conv(x))
        hidden = self.relu(self.fc1(x.flatten(start_dim=1)))
        return self.sigmoid(self.fc2(hidden))

class LinearModel(torch.nn.Module):
    """Fully connected baseline: flatten -> 512 -> 64 -> 1, sigmoid on the output.

    Everything from dim 1 onward is flattened, so each sample must contain
    ``input_size * 3`` elements (e.g. shape ``(batch, 3, input_size)``). The
    output has shape ``(batch, 1)`` with values in (0, 1).

    Args:
        input_size: sequence length used to size the first linear layer.
    """

    def __init__(self, input_size: int = 300):
        super().__init__()
        flat_features = 3 * input_size
        # Each layer's width is chained from the previous one's output size.
        self.fc1 = torch.nn.Linear(in_features=flat_features, out_features=512)
        self.fc2 = torch.nn.Linear(in_features=self.fc1.out_features, out_features=64)
        self.fc3 = torch.nn.Linear(in_features=self.fc2.out_features, out_features=1)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x: torch.Tensor):
        out = x.flatten(start_dim=1)
        for hidden_layer in (self.fc1, self.fc2):
            out = self.relu(hidden_layer(out))
        return self.sigmoid(self.fc3(out))

# initialize the model
net = SimpleModel()


class WeightedBCELoss(torch.nn.Module):
    """Binary cross-entropy on probabilities, with the positive class up-weighted.

    SimpleModel and LinearModel both end in ``Sigmoid``, so their outputs are
    already probabilities. ``BCEWithLogitsLoss`` would apply a second sigmoid to
    them, squashing every prediction into (0.5, ~0.73) and preventing the loss
    from ever reaching confident predictions. For 0/1 targets this loss matches
    ``BCEWithLogitsLoss(pos_weight=pos_weight)`` applied to the corresponding
    logits: positive targets are weighted by ``pos_weight``, negatives by 1.

    Args:
        pos_weight: weight applied to the loss of samples whose target is 1.
    """

    def __init__(self, pos_weight: float):
        super().__init__()
        self.pos_weight = float(pos_weight)

    def forward(self, probs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        targets = targets.to(probs.dtype)
        # Per-element weight: pos_weight where target == 1, 1 where target == 0.
        weights = 1.0 + (self.pos_weight - 1.0) * targets
        return torch.nn.functional.binary_cross_entropy(probs, targets, weight=weights)


# training
# NOTE(review): pos_weight is conventionally (#negative samples) / (#positive samples).
# 2600 and 400 look like class counts -- confirm that label 1 really is the
# 400-sample class; if label 1 has 2600 samples this should be 400 / 2600.
criterion = WeightedBCELoss(pos_weight=2600 / 400)
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# alternative optimizer that was also tried:
# optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

# The SimpleModel note reports convergence around epoch 20; this run is shorter.
num_epochs = 5
net, stats = mtm.train(net, num_epochs, criterion, optimizer, train_dataloader)

print(f'Final stats: {stats}')