In [None]:
import numpy as np
import time
from scipy.sparse import coo_matrix, csr_matrix, vstack

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
assert(torch.cuda.is_available())
# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cudnn.allow_tf32 = True

In [None]:
# Load the pre-extracted splits: feature matrices (presumably SciPy CSR, given
# how they are sliced and unpacked further down) and raw labels.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
# files. Newer PyTorch releases default to weights_only=True, which may reject
# non-tensor objects stored here; pass weights_only=False if loading fails.
(Xtrain, ytrain,
 Xvalid, yvalid,
 Xtest, ytest) = torch.load('build/extracted+reg.pt')

In [None]:
class Str2idx():
    """Callable lookup table from labels to consecutive integer indices.

    Indices follow the iteration order of ``myset`` (0, 1, 2, ...). If a label
    appears more than once, its last position wins. Labels that were never
    seen map to ``-1``.
    """
    def __init__(self, myset) -> None:
        self.idxDict = {label: position for position, label in enumerate(myset)}

    def __call__(self, query):
        # Unknown labels fall back to the sentinel -1.
        return self.idxDict.get(query, -1)

# Build the label -> class-index mapping in a deterministic order. Iterating a
# bare `set` has no defined order; for str labels it even changes with Python's
# per-process hash seed. Then the class-weight vector of the loss and the
# 2*(idx+1) rating used by `MSE` would refer to different labels between
# kernel restarts.
# NOTE(review): assumes the labels are mutually comparable and that their
# sorted order is the intended class order; check `print(fit2idx.idxDict)`.
validLabels = sorted(set(ytrain))
fit2idx = Str2idx(validLabels)

def LabelToIdx(data, bs=4096, mapper=None):
    """Map every raw label in ``data`` to its integer class index.

    Args:
        data: iterable of raw labels.
        bs: unused; kept only so existing calls that pass it keep working.
        mapper: callable ``label -> index``. Defaults to the notebook-level
            ``fit2idx``, which is looked up at call time. Labels the mapper
            does not know get whatever it returns for them (``-1`` for a
            ``Str2idx``).

    Returns:
        A list of class indices in the same order as ``data``.
    """
    if mapper is None:
        mapper = fit2idx
    return [mapper(x) for x in data]

# Replace each split's raw labels with integer class indices (-1 for labels
# never seen in the training split). All three are computed before any name is
# rebound.
ytrain, yvalid, ytest = [LabelToIdx(split) for split in (ytrain, yvalid, ytest)]

In [None]:
# Inspect the label -> class-index mapping. Its order matters downstream: the
# CrossEntropyLoss `weight` vector and the 2*(idx+1) rating inside `MSE` are
# both indexed by these class indices.
print(fit2idx.idxDict)

In [None]:
def ToOneHot(rawVec, numValues=5, bs=4096, device='cuda'):
    """Turn integer class indices into batches of one-hot float targets.

    Args:
        rawVec: 1-D sequence of integer class indices in ``[0, numValues)``.
        numValues: number of classes (the width of each one-hot row).
        bs: rows per returned batch; the last batch may be shorter.
        device: device each batch is moved to. The default ``'cuda'`` matches
            the original behaviour.

    Returns:
        A list of float32 tensors of shape ``(<= bs, numValues)``. Empty input
        gives an empty list.

    Raises:
        ValueError: if any index lies outside ``[0, numValues)``. In particular
            ``-1`` (what ``Str2idx`` returns for a label unseen in training)
            would otherwise be treated by NumPy's negative indexing as the
            last class, silently mislabelling the sample.
    """
    labels = np.asarray(rawVec)
    if labels.size == 0:
        labels = labels.astype(np.int64)  # np.asarray([]) is float64, which cannot index
    bad = labels[(labels < 0) | (labels >= numValues)]
    if bad.size:
        raise ValueError(
            f"class indices must lie in [0, {numValues}); got {bad[:5].tolist()} "
            "(-1 usually means the label was not seen in the training split)"
        )
    n = labels.shape[0]
    oneHot = np.zeros((n, numValues), dtype=np.float32)
    oneHot[np.arange(n), labels] = 1.0
    # torch.tensor copies each slice, so batches do not share memory with oneHot.
    return [torch.tensor(oneHot[i:i + bs], device=device) for i in range(0, n, bs)]

# One-hot encode each split's class indices and cut them into CUDA batches.
# Unpacking consumes the generator fully before any name is rebound.
ytrain, yvalid, ytest = (ToOneHot(split) for split in (ytrain, yvalid, ytest))

In [None]:
def CsrToTorchSparse(csr, bs=4096, device='cuda'):
    """Split a SciPy sparse matrix into row batches of torch sparse CSR tensors.

    Args:
        csr: SciPy sparse matrix. Any format is converted with ``.tocsr()``,
            which returns the same object when it already is CSR.
        bs: rows per batch; the last batch may be shorter.
        device: device each batch is moved to. The default ``'cuda'`` matches
            the original behaviour.

    Returns:
        A list of float32 ``torch.sparse_csr`` tensors, one per row batch.
        A matrix with zero rows gives an empty list.
    """
    csr = csr.tocsr()
    batches = []
    for start in range(0, csr.shape[0], bs):
        chunk = csr[start:start + bs]
        tensor = torch.sparse_csr_tensor(
            chunk.indptr, chunk.indices, chunk.data, chunk.shape, dtype=torch.float32
        )
        batches.append(tensor.to(device))
    return batches

# Cut each split's sparse feature matrix into batches of CUDA sparse CSR tensors.
# The names are rebound to lists, so re-running this cell on the already
# converted values would fail (a list has no `.shape`).
Xtrain = CsrToTorchSparse(Xtrain)
Xvalid = CsrToTorchSparse(Xvalid)
Xtest = CsrToTorchSparse(Xtest)

In [None]:
class SparseDataset(Dataset):
    """
    Dataset pairing two equally long, indexable sequences item by item.

    In this notebook ``data`` is the list of torch sparse CSR feature batches
    from ``CsrToTorchSparse`` and ``targets`` is the matching list of dense
    one-hot batches from ``ToOneHot``. Each item is therefore a whole
    ``(data, target)`` batch, not a single sample.
    """
    def __init__(self, data, targets):
        super().__init__()
        # Iterating the dataset stops at the first IndexError from either
        # sequence, so a length mismatch would silently drop items. Reject it.
        if len(data) != len(targets):
            raise ValueError(
                f"data and targets must have the same length, got {len(data)} and {len(targets)}"
            )
        self.data = data                # features, one entry per item
        self.targets = targets          # targets aligned with `data`

    def __getitem__(self, index: int):
        return self.data[index], self.targets[index]

    def __len__(self):
        return len(self.data)
    

In [None]:
# Pair feature batches with target batches. Each item is one pre-built batch,
# so iterating a dataset yields (data, target) batches, not single samples.
train_set = SparseDataset(Xtrain, ytrain)
valid_set = SparseDataset(Xvalid, yvalid)
test_set = SparseDataset(Xtest, ytest)

In [None]:
class MyMLP(nn.Module):
    """Three-layer MLP classifier: in_features -> 512 -> 100 -> num_classes.

    A ReLU follows each of the two hidden layers. The output layer returns raw
    logits, which are fed to ``CrossEntropyLoss`` in this notebook. The defaults
    reproduce the original fixed architecture (1000 input features, 5 classes).
    The submodule names ``fc1``/``fc2``/``fc3`` are kept so existing state dicts
    still load.
    """
    def __init__(self, in_features=1000, num_classes=5):
        super().__init__()
        self.fc1 = nn.Sequential(nn.Linear(in_features, 512), nn.ReLU())
        self.fc2 = nn.Sequential(nn.Linear(512, 100), nn.ReLU())
        self.fc3 = nn.Sequential(nn.Linear(100, num_classes))

    def forward(self, x):
        """Map a ``(batch, in_features)`` input to ``(batch, num_classes)`` logits."""
        x = self.fc1(x)
        x = self.fc2(x)
        return self.fc3(x)

In [None]:
# Class-weighted cross-entropy: weight[i] applies to class index i, i.e. the
# order shown by fit2idx.idxDict. NOTE(review): the values look like inverse
# class frequencies; confirm how they were derived. The unweighted alternative
# is `nn.CrossEntropyLoss().cuda()`.
criterion = nn.CrossEntropyLoss(
    weight=torch.Tensor([2.35, 1.2, 0.45, 0.16, 0.09]),
).cuda()
# NOTE(review): not read by the helpers in this notebook, which take their own
# `bs` argument; possibly a leftover.
bs = 4096

def MSE(pred, target):
    """Mean squared error between the ratings implied by two score matrices.

    Each row's argmax class index ``i`` is mapped to the rating ``2 * (i + 1)``.
    The squared rating differences are averaged over rows and returned as a
    Python float.
    """
    def as_rating(scores):
        return (scores.argmax(dim=1) + 1).float() * 2

    diff = as_rating(pred) - as_rating(target)
    return diff.pow(2).mean().cpu().item()

@torch.no_grad()
def val(model, dataset=valid_set):
    """Evaluate ``model`` on a batched dataset without tracking gradients.

    Args:
        model: classifier returning one logit per class (5 classes).
        dataset: sequence of ``(data, target)`` batches whose targets are
            one-hot rows (as built by ``ToOneHot``). Defaults to ``valid_set``.

    Returns:
        ``(val_loss, p, u, t, mse)``:
          - ``val_loss``: ``criterion`` averaged over samples (0-d tensor).
          - ``p``, ``u``, ``t``: per-class counts of predictions, correct
            predictions and true labels (float CPU tensors of length 5).
          - ``mse``: ``MSE`` averaged over samples.

    Raises:
        ValueError: if ``dataset`` contains no samples.
    """
    model.eval()
    num_classes = 5
    p = torch.zeros(num_classes)
    u = torch.zeros(num_classes)
    t = torch.zeros(num_classes)
    val_loss = 0.0
    mse = 0.0
    n_samples = 0
    for data, target in dataset:
        logits = model(data)
        batch_size = target.shape[0]
        # criterion and MSE return per-batch means. Weight them by batch size so
        # the shorter final batch does not count as much as a full one.
        val_loss += criterion(logits, target) * batch_size
        pred = nn.functional.one_hot(logits.argmax(dim=1), num_classes=num_classes)
        p += pred.sum(dim=0).cpu()
        u += (pred * target).sum(dim=0).cpu()
        t += target.sum(dim=0).cpu()
        mse += MSE(pred, target) * batch_size
        n_samples += batch_size
    if n_samples == 0:
        raise ValueError("val() needs a dataset with at least one sample")
    return val_loss / n_samples, p, u, t, mse / n_samples

def train(model, optimizer, epoch, lr_scheduler=None, grad_clip=None):
    """Run one epoch over ``train_set``, validate, and print a one-line summary.

    Args:
        model: network being trained (switched to train mode here).
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only in the printed message.
        lr_scheduler: optional scheduler, stepped once at the end of the epoch.
        grad_clip: if truthy, gradients are clipped element-wise to
            ``[-grad_clip, grad_clip]`` before each optimizer step.

    Reads the notebook globals ``train_set`` (a sequence of ``(data, target)``
    batches) and ``criterion``, and calls ``val`` on its default dataset.
    """
    start_time = time.time()
    model.train()
    train_loss = 0.0
    n_samples = 0
    for data, target in train_set:
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        if grad_clip:
            nn.utils.clip_grad_value_(model.parameters(), grad_clip)
        optimizer.step()
        # `loss` is a per-batch mean. Weight it by batch size so the shorter
        # final batch does not count as much as a full one in the epoch average.
        batch_size = target.shape[0]
        train_loss += loss.detach() * batch_size
        n_samples += batch_size
    if lr_scheduler is not None:
        lr_scheduler.step()
    train_loss = train_loss / n_samples
    val_loss, p, u, t, mse = val(model)
    f1 = 2 * u / (p + t)  # per-class F1 = 2*TP / (#predicted + #actual)
    end_time = time.time()
    msg = f"Epoch: {epoch} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | u: {u} | p: {p} | t: {t} | Val F1: {f1} | Val MSE: {mse} | time: {end_time - start_time:.1f}"
    print(msg)

In [None]:
# No bias decay
def create_param_groups(model):
    """Split ``model``'s parameters into a decayed and a non-decayed group.

    Weights of Linear, Conv2d, BatchNorm1d and BatchNorm2d layers go in the
    first group, which uses the optimizer's default weight decay. Their biases
    go in a second group with ``weight_decay=0.0``. Raises ``AssertionError``
    when the model has parameters outside these layer types.
    """
    decay, no_decay = [], []
    for module in model.modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            decay.append(module.weight)
        elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
            # affine=False batch norms have no weight.
            if module.weight is not None:
                decay.append(module.weight)
        else:
            continue
        if module.bias is not None:
            no_decay.append(module.bias)
    # Every parameter must land in exactly one group.
    assert sum(1 for _ in model.parameters()) == len(decay) + len(no_decay)
    return [{'params': decay}, {'params': no_decay, 'weight_decay': 0.0}]

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR

EP = 500    # number of epochs; also the cosine schedule's period (T_max)
SEED = 0    # fixes weight initialisation so training runs are reproducible

torch.manual_seed(SEED)
model = MyMLP().cuda()
# Plain weight decay on every parameter. Pass `create_param_groups(model)`
# instead of `model.parameters()` to exempt biases from decay.
optimizer = optim.SGD(
    model.parameters(),
    weight_decay=5e-4,
    lr=0.5,
)
lr_scheduler = CosineAnnealingLR(optimizer, EP)

for ep in range(EP):
    train(model, optimizer, ep, lr_scheduler, grad_clip=None)

In [None]:
# Final evaluation on the held-out test split, using the same metrics as validation.
val_loss, p, u, t, mse = val(model, test_set)
f1 = 2 * u / (p + t)  # per-class F1 = 2*TP / (#predicted + #actual)
msg = f"Test Loss: {val_loss:.4f} | u: {u} | p: {p} | t: {t} | Test F1: {f1} | Test MSE: {mse}"
print(msg)