In [29]:
import torch
import xarray
from datasets import *
from models import *
from dl_pipeline import *
from loss import *
import seaborn
import numpy as np
import matplotlib.pyplot as plt

In [30]:
base_train_ds = CycloneDataset('/g/data/x77/ob2720/partition/train/', tracks_path=train_json_path, 
                            save_np=False, load_np=True, partition_name='train', synthetic=True, 
                            synthetic_type='base_synthesis', sigma=0.1)
normal_perturb_train_ds = CycloneDataset('/g/data/x77/ob2720/partition/train/', tracks_path=train_json_path, 
                            save_np=False, load_np=True, partition_name='train', synthetic=True, 
                            synthetic_type='normal_perturb_synthesis', sigma=0.1)

In [31]:
base_loader = torch.utils.data.DataLoader(base_train_ds, batch_size=256, num_workers=8, pin_memory=True, shuffle=True)
normal_perturb_loader = torch.utils.data.DataLoader(normal_perturb_train_ds, batch_size=256, num_workers=8, pin_memory=True, shuffle=True)

In [32]:
prediction_model = predictionANN(1)

In [33]:
print(len(base_train_ds))

188156


In [34]:
def train(model, train_loader, ds_length, num_epochs=5, batch_size=256):
    loss_fn = L2_Dist_Func_Mae().to(0)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
    
    for epoch in range(num_epochs):
        model.train()
        avg_loss = 0
        for i, data in tqdm(enumerate(train_loader), total = ds_length/batch_size):
            
            if data == []:
                continue
            else:
                example = data[0]
                label = data[1]
            
            pred = model.forward(example)
            optimizer.zero_grad(set_to_none=True)
            loss = loss_fn(pred, label)
            loss.backward()
            optimizer.step()
            avg_loss += loss.item()     

        print(f"Average loss: {loss}")
    
    return model

In [None]:
train(prediction_model, base_loader, len(base_train_ds), num_epochs=20)

100%|██████████| 735/734.984375 [03:29<00:00,  3.50it/s]

Average loss: 29.728195333291612



100%|██████████| 735/734.984375 [03:21<00:00,  3.66it/s]

Average loss: 31.21376988268088



100%|██████████| 735/734.984375 [03:21<00:00,  3.64it/s]

Average loss: 28.612871387648205



100%|██████████| 735/734.984375 [03:22<00:00,  3.63it/s]

Average loss: 28.294182322801106



100%|██████████| 735/734.984375 [03:22<00:00,  3.64it/s]

Average loss: 30.450344560165252



100%|██████████| 735/734.984375 [03:21<00:00,  3.64it/s]

Average loss: 28.369306607851907



100%|██████████| 735/734.984375 [03:22<00:00,  3.63it/s]

Average loss: 28.896483935060957



100%|██████████| 735/734.984375 [03:22<00:00,  3.63it/s]

Average loss: 30.25399226990957



100%|██████████| 735/734.984375 [03:23<00:00,  3.62it/s]

Average loss: 28.783015247848297



100%|██████████| 735/734.984375 [03:21<00:00,  3.65it/s]

Average loss: 28.97770292276428



100%|██████████| 735/734.984375 [03:21<00:00,  3.65it/s]

Average loss: 28.742693858487264



 92%|█████████▏| 674/734.984375 [03:05<00:13,  4.37it/s]

In [21]:
np.load('/g/data/x77/jm0124/synthetic_datasets/base_synthesis/u/2/0.1/train/1985240N16256-2.npy', allow_pickle=True).shape

(1, 1, 20, 20)

In [None]:
prediction_model = predictionANN(1)

In [36]:
train(prediction_model, base_loader, len(base_train_ds), num_epochs=10)
train(prediction_model, normal_perturb_loader, len(normal_perturb_train_ds), num_epochs=5)

100%|██████████| 735/734.984375 [03:22<00:00,  3.64it/s]

Average loss: 30.18215855908772



100%|██████████| 735/734.984375 [03:21<00:00,  3.65it/s]

Average loss: 30.482336185754292



100%|██████████| 735/734.984375 [03:20<00:00,  3.66it/s]

Average loss: 28.811776593090997



100%|██████████| 735/734.984375 [03:20<00:00,  3.66it/s]

Average loss: 30.152442959565963



100%|██████████| 735/734.984375 [03:20<00:00,  3.66it/s]

Average loss: 30.002978052884814



100%|██████████| 735/734.984375 [12:38<00:00,  1.03s/it]

Average loss: 30.966229676254212



100%|██████████| 735/734.984375 [03:20<00:00,  3.66it/s]

Average loss: 32.41251075930066



100%|██████████| 735/734.984375 [03:21<00:00,  3.66it/s]

Average loss: 28.9825183678241



100%|██████████| 735/734.984375 [03:21<00:00,  3.65it/s]

Average loss: 29.729422370592754



100%|██████████| 735/734.984375 [03:21<00:00,  3.64it/s]

Average loss: 31.92813471763853





predictionANN(
  (fc1): Linear(in_features=400, out_features=16, bias=True)
  (fc2): Linear(in_features=16, out_features=16, bias=True)
  (fc3): Linear(in_features=16, out_features=2, bias=True)
)

In [None]:
cyclone_ds = xarray.open_dataset('/g/data/x77/ob2720/partition/train/1985240N16256.nc', 
                            engine='netcdf4', decode_cf=True, cache=True)

cyclone_ds = cyclone_ds[dict(time=[1],
                        level=[2])]['u']
print(cyclone_ds)

cyclone_array = cyclone_ds.to_numpy()

cyclone_array = cyclone_array[:,:,40:120,40:120]
cyclone_array = cyclone_array[:,:,::4,::4]
size = 20

cyclone_array = cyclone_array.reshape(1, -1, size, size) 
np.save('/g/data/x77/jm0124/synthetic_datasets/base_synthesis/u/2/0.1/train/1985240N16256-2.npy', cyclone_array)