In [1]:
import os

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from torch_geometric.data import DataLoader
from tqdm import tqdm

In [2]:
from dataset import COVIDDataset
from st_gnn import SpatioTemporalGCN

In [3]:
# Build both splits. The test split reuses the normalisation parameters
# computed on the training split so the two are scaled identically.
train_dataset = COVIDDataset('train')
test_dataset = COVIDDataset(
    'test',
    normalization_params=train_dataset.normalization_params,
)

In [10]:
# Shuffled mini-batches for training; the whole test split is served as one batch.
# NOTE(review): newer PyTorch Geometric releases expose this class as
# torch_geometric.loader.DataLoader — confirm against the installed version.
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=8,
)
test_loader = DataLoader(test_dataset, batch_size=len(test_dataset))

In [11]:
class RMSLELoss(torch.nn.Module):
    """Root mean squared logarithmic error.

    Computes ``sqrt(mean((log(1 + pred) - log(1 + actual)) ** 2))``.

    Inputs are expected to be greater than -1 (normally non-negative);
    values <= -1 make the logarithm undefined and produce NaN/-inf.
    """

    def __init__(self):
        super().__init__()
        # Mean-reduced squared error over the log-transformed values.
        self.mse = torch.nn.MSELoss()

    def forward(self, pred, actual, mask=None):
        """Return the RMSLE between ``pred`` and ``actual`` as a 0-d tensor.

        Args:
            pred: predicted values.
            actual: target values, same shape as ``pred``.
            mask: optional boolean index applied to both ``pred`` and
                ``actual`` before the error is computed. An all-False mask
                leaves nothing to average and yields NaN.
        """
        if mask is not None:
            pred, actual = pred[mask], actual[mask]
        # log1p(x) == log(1 + x) but stays accurate when x is close to 0.
        return torch.sqrt(self.mse(torch.log1p(pred), torch.log1p(actual)))

In [12]:
# Seed torch's global RNG so weight initialisation (and, presumably, the
# shuffling of train_loader, which draws from the global RNG by default)
# is repeatable across runs.
SEED = 0
torch.manual_seed(SEED)

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = SpatioTemporalGCN(num_temp_features=8, num_static_features=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
writer = SummaryWriter()
criterion = torch.nn.MSELoss()   # training objective used by train()
test_criterion = RMSLELoss()     # evaluation metric used by test()

def train():
    """Run one optimisation epoch over ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion`` and
    ``device``.

    Returns:
        float: mean of the per-batch training losses for this epoch
        (0.0 if the loader yields no batches).
    """
    model.train()

    total_loss = 0.0
    num_batches = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out.squeeze(), data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    # criterion (MSELoss, mean reduction) already averages within a batch, so
    # average over batches. Dividing by len(train_dataset) (a sample count)
    # under-reported the epoch loss by roughly the batch size.
    return total_loss / num_batches if num_batches else 0.0

def test(loader):
    model.eval()
    data = next(iter(loader)).to(device)
    with torch.no_grad():
        pred_y = model(data)
    mse = torch.sqrt(torch.mean((pred_y.squeeze() - data.y)**2))
    rmsle = test_criterion(pred_y.squeeze(), data.y)
    rmsle_20 = test_criterion(pred_y.squeeze(), data.y, data.tc.type(torch.bool))
    return mse, rmsle, rmsle_20

In [13]:
def save(model, optimizer, epoch, directory="models"):
    """Write a training checkpoint to ``<directory>/modeld_<epoch>.pt``.

    Args:
        model: module whose ``state_dict()`` is stored.
        optimizer: optimizer whose ``state_dict()`` is stored.
        epoch: epoch number; stored in the checkpoint and used in the filename.
        directory: destination folder, created if it does not exist yet
            (default ``"models"``, relative to the working directory).

    Returns:
        str: path of the written checkpoint file.
    """
    # Create the folder up front so the first checkpoint of a fresh checkout
    # doesn't crash a long training run.
    os.makedirs(directory, exist_ok=True)
    # NOTE(review): the "modeld_" prefix is kept verbatim so existing
    # checkpoints and loaders keep working — confirm whether it is a typo.
    path = os.path.join(directory, "modeld_%d.pt" % epoch)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)
    return path

In [None]:
# Training schedule (epochs are 1-indexed; the last epoch run is MAX_EPOCH).
MAX_EPOCH = 99_999
EVAL_EVERY = 10   # compute train/test metrics every N epochs
SAVE_EVERY = 50   # write a checkpoint every N epochs

try:
    for epoch in tqdm(range(1, MAX_EPOCH + 1)):
        loss = train()
        writer.add_scalar('loss', loss, epoch)
        if epoch % EVAL_EVERY == 0:
            # test() returns (RMSE, RMSLE, RMSLE restricted to the data.tc mask).
            # The 'MSE/...' tag is kept so existing TensorBoard runs stay comparable.
            for split, loader in (('train', train_loader), ('test', test_loader)):
                rmse, rmsle, rmsle_top20 = test(loader)
                writer.add_scalar(f'MSE/{split}', rmse, epoch)
                writer.add_scalar(f'RMSLE/{split}', rmsle, epoch)
                writer.add_scalar(f'RMSLE/{split}_top_20', rmsle_top20, epoch)
        if epoch % SAVE_EVERY == 0:
            save(model, optimizer, epoch)
finally:
    # Push buffered TensorBoard events to disk even if the run is interrupted.
    writer.flush()

  0%|          | 72/99999 [02:10<52:49:44,  1.90s/it]

In [24]:
# Evaluate on the test split; displays (RMSE, RMSLE, RMSLE over entries where data.tc is set).
test(test_loader)

(tensor(0.1221, device='cuda:0'),
 tensor(0.0228, device='cuda:0'),
 tensor(0.0078, device='cuda:0'))

In [None]:
# NOTE(review): stray scratch cell — it only evaluates the literal 4750 and affects nothing else; consider deleting it.
4750