In [None]:
import os
import time
import torch
import torch.optim as optim

from datetime import date
from model import SUPRESCNN
from trainer import train, validation
from dataset_creater import datasetCreate
from get_dataset import getTrainDataset, getValDataset
from torch.utils.data.dataloader import DataLoader
from utils import config


In [None]:
# Runtime setup: the target device and the loss function are taken
# directly from the project-level config module.
device = config.DEVICE
criterion = config.LOSS

# Instantiate the network on the selected device and attach an Adam
# optimizer over all of its parameters (learning rate from config).
model = SUPRESCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

In [None]:
# Build the on-disk datasets from the raw train/validation directories.
# NOTE(review): presumably this writes the h5 files read in the next cell
# ("h5_file/train_h5", "h5_file/eval_h5") — confirm in dataset_creater.
datasetCreate(
    config.train_dir,
    config.val_dir,
)

In [None]:
# Training data: shuffled every epoch.
train_dataset = getTrainDataset("h5_file/train_h5")
train_loader = DataLoader(
    train_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
)

# Validation data: fixed order so evaluation is repeatable.
val_dataset = getValDataset("h5_file/eval_h5")
val_loader = DataLoader(
    val_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=False,
)

In [None]:
start = time.time()

# Per-epoch loss history. Without these lists each epoch's value would be
# overwritten on the next iteration, and only the final epoch would survive.
train_losses = []
val_losses = []

for epoch in range(config.EPOCH):
    print(f"EPOCH {epoch+1} of {config.EPOCH}")

    train_epoch_loss = train(model, train_dataset, train_loader, device, criterion, optimizer, epoch)
    # NOTE(review): `validation` also receives the optimizer; confirm it does
    # not step it (it should only evaluate).
    val_epoch_loss = validation(model, val_dataset, val_loader, device, criterion, optimizer, epoch)

    train_losses.append(train_epoch_loss)
    val_losses.append(val_epoch_loss)

end = time.time()

# exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs("output", exist_ok=True)

# The file name only carries the scale factor and the date, so a second run on
# the same day with the same scale factor overwrites the earlier checkpoint.
save_name = f"output/modelx{config.scale_factor}_{date.today()}.pth"
torch.save(model.state_dict(), save_name)
print(f"{((end-start)/60):.3f} minutes to train...")