In [1]:
from utils_torch import *

import random

import numpy as np
import torch

# Seed every RNG before any random work happens so that DataLoader shuffling,
# random_split and weight initialisation downstream are reproducible under
# Restart Kernel -> Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

proxy = torch_SpatioTemporalCO2()
# NOTE(review): check_torch_gpu presumably prints the version banner shown in the
# output below and returns the torch device to train on — confirm in utils_torch.
device = proxy.check_torch_gpu()

# make_dataloaders returns (train loader, test loader, train dataset, test dataset).
train_dataloader, test_dataloader, train_dataset, test_dataset = proxy.make_dataloaders()

-------------------------------------------------
------------------ VERSION INFO -----------------
Conda Environment: Python39
Torch version: 2.0.0+cu117
Torch build with CUDA? True
# Device(s) available: 1, Name(s): NVIDIA GeForce RTX 3080



In [2]:
import torch
from torch.utils.data import DataLoader, random_split

# --- Configuration -------------------------------------------------------------
num_epochs = 50
batch_size = 32
TRAIN_FRACTION = 0.8  # share of train_dataset used for fitting; the rest is validation
SPLIT_SEED = 42       # fixes the train/validation partition so re-runs are comparable
LOG_EVERY = 5         # print losses every LOG_EVERY epochs


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over `loader`; return the mean per-batch loss."""
    model.train()
    running_loss = 0.0
    for x, y in loader:
        x, y = x.float().to(device), y.float().to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader)


@torch.no_grad()
def _evaluate(model, loader, criterion, device):
    """Return the mean per-batch loss of `model` over `loader` without updating weights."""
    model.eval()
    running_loss = 0.0
    for x, y in loader:
        x, y = x.float().to(device), y.float().to(device)
        running_loss += criterion(model(x), y).item()
    return running_loss / len(loader)


# --- Model, optimiser, LR schedule ---------------------------------------------
model = ProxyModel()
model.to(device)
print('# Params: {:,}'.format(count_params(model)))

optimizer = torch.optim.NAdam(model.parameters(), lr=8e-3, weight_decay=1e-4)
# T_max spans the whole run so the LR decays monotonically. With T_max shorter
# than num_epochs, CosineAnnealingLR would start rising again after T_max steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
criterion = CustomLoss().to(device)

# --- Fixed train/validation split ----------------------------------------------
# Split ONCE, before training. Re-splitting every epoch lets samples validated on
# in one epoch be trained on in the next, so the validation loss stops measuring
# generalisation.
n_train = int(len(train_dataset) * TRAIN_FRACTION)
n_valid = len(train_dataset) - n_train
if n_train == 0 or n_valid == 0:
    raise ValueError(
        f'train_dataset too small for a {TRAIN_FRACTION:.0%} split: {len(train_dataset)} samples'
    )
train_subset, val_subset = random_split(
    train_dataset,
    [n_train, n_valid],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# Distinct name so the full-training `train_dataloader` from the data-loading cell
# is not overwritten by a subset loader.
train_subset_dataloader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

# --- Training loop -------------------------------------------------------------
train_losses, valid_losses = [], []  # per-epoch history, e.g. for plotting loss curves
for epoch in range(num_epochs):
    tot_train_loss = _train_one_epoch(model, train_subset_dataloader, criterion, optimizer, device)
    tot_valid_loss = _evaluate(model, valid_dataloader, criterion, device)
    scheduler.step()  # advance the cosine schedule once per epoch
    train_losses.append(tot_train_loss)
    valid_losses.append(tot_valid_loss)

    if (epoch + 1) % LOG_EVERY == 0:
        print('Epoch: [{}/{}] | Loss: {:.4f} | Validation Loss: {:.4f}'.format(epoch+1, num_epochs, tot_train_loss, tot_valid_loss))

print("Training finished.")

# Params: 373,274
