In [5]:
from models.swapnet import SWAP
import torch
from utils import  get_data_azimuths
from dataset import Seismic
import numpy as np
from pathlib import Path
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR

class config:
    """Notebook-wide settings used by the training cells below.

    Used as a plain namespace (never instantiated): every setting is a class
    attribute read as ``config.<name>``.
    """

    # Label encoding: both values are forwarded to ``Seismic``; ``label_delta``
    # also sizes the model output (``360 // label_delta``), so it is presumably
    # an angular bin width in degrees.
    label_sigma = 10
    label_delta = 10

    # Optimisation: initial Adam learning rate and number of training epochs.
    epochs = 50
    lr = 1e-2

    # Runtime: evaluated once, when this cell runs.
    use_gpu = torch.cuda.is_available()

    # Checkpoint location, resolved against the current working directory
    # at class-definition time.
    model_checkpoint_name = 'swapnet'
    model_checkpoint_path = Path.cwd() / 'models' / 'pretrained' / f'{model_checkpoint_name}.pt'

In [6]:
def train(model, trainset, batch_size=256, final_lr=1e-4, epochs=None, lr=None, use_gpu=None):
    """Train ``model`` in place on ``trainset`` with Adam and return it.

    The learning rate starts at ``lr``. It is multiplied by a constant factor
    after every epoch (``StepLR`` with ``step_size=1``), chosen so that it
    reaches ``final_lr`` after the last epoch.

    Args:
        model: a ``torch.nn.Module`` exposing ``train_loss(x, y)``, which must
            return a tensor that ``.backward()`` can be called on (a scalar loss).
        trainset: a dataset whose batches unpack as ``(x, y)``.
        batch_size: mini-batch size (default 256).
        final_lr: learning rate reached after the last epoch (default 1e-4).
        epochs: number of epochs; ``None`` means ``config.epochs``.
        lr: initial learning rate; ``None`` means ``config.lr``.
        use_gpu: move every batch to CUDA; ``None`` means ``config.use_gpu``.

    Returns:
        The same ``model`` object, left in training mode.
    """
    # Fall back to the notebook-wide config only for values the caller did not pass.
    epochs = config.epochs if epochs is None else epochs
    lr = config.lr if lr is None else lr
    use_gpu = config.use_gpu if use_gpu is None else use_gpu

    dataloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
    optimizer = Adam(model.parameters(), lr=lr)
    # Per-epoch decay factor so that lr * gamma**epochs == final_lr.
    # Guard epochs == 0, which previously raised ZeroDivisionError.
    gamma = (final_lr / lr) ** (1 / epochs) if epochs > 0 else 1.0
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma, last_epoch=-1)

    print('start training.')
    model.train()
    for epoch in range(epochs):
        for x, y in dataloader:
            if use_gpu:
                x, y = x.cuda(), y.cuda()
            loss = model.train_loss(x, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Step the schedule once per epoch, after that epoch's optimizer steps.
        scheduler.step()
        # Overwrite the same console line with the current (0-based) epoch index.
        print(f'\r{epoch}', end='')
    print()
    return model


In [7]:
# Build the network. c_out = 360 // label_delta, presumably one output per
# label_delta-degree azimuth bin.
model = SWAP(c_in = 3, c_out = 360//config.label_delta, nf=32, adaptive_size=25, ks1 = [17, 11, 5], ks2=7)
if config.use_gpu:
    model = model.cuda()

if config.model_checkpoint_path.exists():
    print('checkpoint exists.')
    # map_location allows loading a checkpoint saved from a GPU run on a CPU-only
    # machine; load_state_dict then copies the tensors into the model's parameters.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(config.model_checkpoint_path,
                            map_location='cuda' if config.use_gpu else 'cpu')
    model.load_state_dict(state_dict)
else:
    trainset = Seismic(*get_data_azimuths('data', 'train'), signal_start=50, signal_length=200,
                    label_sigma=config.label_sigma,
                    label_delta=config.label_delta,
                    is_aug_shift=True)
    model = train(model, trainset)
    # models/pretrained/ may not exist yet. Create it so a finished training run
    # is not lost to FileNotFoundError at save time.
    config.model_checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), config.model_checkpoint_path)

checkpoint exists.
