## Training
Prepare the dataset with the prepare_dataset notebook before running this one.

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import scipy.io as sio
from time import time, sleep
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")  # adds seaborn style to charts, eg. grid
plt.style.use("dark_background")  # inverts colors to dark theme
plt.rcParams['font.family'] = 'monospace'
import os
import warnings; warnings.filterwarnings("ignore")
from utils import calc_gso_batch # gso/pinn calculation
# Cluster vs. local detection: the job id is read from SLURM_JOB_ID, so the
# presence of that variable selects the headless cluster setup.
_on_cluster = "SLURM_JOB_ID" in os.environ
if _on_cluster:
    JOBID = os.environ["SLURM_JOB_ID"] # get job id from slurm, when training on cluster
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") # nvidia
    HAS_SCREEN = False # for plotting or saving images
else:
    JOBID = "local"
    HAS_SCREEN = True
    # prefer apple silicon, but fall back instead of failing later on the first .to(device)
    _mps_ok = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    if _mps_ok:
        device = torch.device("mps") # apple silicon
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
os.makedirs(f"mg_data/{JOBID}", exist_ok=True)
print(f'device: {device}')

# copy the python training to the directory (for cluster) (for local, it fails silently)
os.system(f"cp mg_train2.py mg_data/{JOBID}/mg_train2.py")
os.system(f"cp utils.py mg_data/{JOBID}/utils.py")

def to_tensor(x, device=torch.device("cpu")):
    """Build a new float32 tensor from `x` on `device` (CPU unless told otherwise)."""
    return torch.tensor(x, device=device, dtype=torch.float32)

# network size tags
SMALL = "small"
NORM = "norm"
BIG = "big"
# file-name prefixes for saved checkpoints (MSE / GSO are used when saving the
# best models; TOT presumably stands for the total loss -- not used in this part)
PRENAME_MSE = "mg_planet_mse"
PRENAME_GSO = "mg_planet_gso"
PRENAME_TOT = "mg_planet_tot"

In [2]:
SAVE_DIR = f"mg_data/{JOBID}" # per-run output directory: checkpoints and loss logs are written here
EPOCHS = 1000 # number of epochs, note: LEARNING_RATE and GSO_LOSS_RATIO must have exactly this many entries
BATCH_SIZE = 128 # 128 best

LOAD_PRETRAINED = None # Set it to None if you don't want to load pretrained model
# LOAD_PRETRAINED = "trained_models/pretrained_1809761.pth" # norm model
# LOAD_PRETRAINED = "trained_models/pretrained_small_1810888.pth" # small model
# LOAD_PRETRAINED = "trained_models/pretrained_big_1811142.pth" # big model

# per-epoch learning rate (one value per epoch, applied at the start of the epoch)
# LEARNING_RATE = 3e-4*np.linspace(1, 1e-2, EPOCHS)  # best
LEARNING_RATE = 3e-4*np.logspace(0, -2, EPOCHS)
# LEARNING_RATE = 1e-4*np.logspace(0, -2, EPOCHS)

# per-epoch weight of the GSO term in the total loss: loss = (1-r)*mse + r*gso
# GSO_LOSS_RATIO = np.linspace(0.4, 0.1, EPOCHS) # best
# GSO_LOSS_RATIO = np.linspace(0.3, 0.1, EPOCHS) # best too
# GSO_LOSS_RATIO = np.linspace(0.4, 0.0, EPOCHS) # best for big model pretrain start
# ramp 0.4 -> 0.0 over the first half, then stay at 0.0; the second half takes the
# remaining epochs so the length is EPOCHS even when EPOCHS is odd
GSO_LOSS_RATIO = np.concatenate((np.linspace(0.4, 0.0, EPOCHS//2), np.zeros(EPOCHS - EPOCHS//2)))
# GSO_LOSS_RATIO = 0.1*np.ones(EPOCHS) # not very good
# GSO_LOSS_RATIO = (0.5+0.5*np.sin(np.linspace(0, 25*np.pi, EPOCHS)))*np.linspace(1, 0.1, EPOCHS) # crazy

NCURRS, NPROFS, NMAGS = 14, 202, 187 # input sizes
INPUT_SIZE = NCURRS + NPROFS + NMAGS
TRAIN_DS_PATH = "data/train_ds.mat" # generated by the prepare_dataset notebook
EVAL_DS_PATH = "data/eval_ds.mat"

In [3]:
# checks: fail early, before datasets are loaded or any training starts
if LOAD_PRETRAINED is not None:
    assert os.path.exists(LOAD_PRETRAINED), "Pretrained model does not exist"
for _path, _msg in (
    (TRAIN_DS_PATH, "Training dataset does not exist"),
    (EVAL_DS_PATH, "Evaluation dataset does not exist"),
    (SAVE_DIR, "Save directory does not exist"),
):
    assert os.path.exists(_path), _msg
# both schedules are indexed by epoch, so they need one entry per epoch
for _schedule, _msg in (
    (LEARNING_RATE, "Learning rate array length does not match epochs"),
    (GSO_LOSS_RATIO, "GSO loss ratio array length does not match epochs"),
):
    assert len(_schedule) == EPOCHS, _msg

# Best train runs: 
| ID       | NET | MSE    | GSO    | Pre | Notes |
|----------|-----|--------|--------|-----|-------|
| 1810019  | norm | 0.0072 | 0.0730 | / | / |
| 1809989  | norm | 0.0082 | 0.3061 | / | / |
| 1809986  | norm | 0.0050 | 0.0727 | / | / |
| 1809768  | norm | 0.0030 | 0.0500 | / | / |
| 1809761  | norm | 0.0060 | 0.0840 | / | / |
| 1810294  | norm | 0.0026 | 0.0487 | 1809768 | pre-trained from 1809768 |
| 1810825  | small | 0.0265 | 0.2606 | / | small model |
| 1810888  | small | 0.0234 | 0.2236 | / | has potential for improvement |
| 1810897  | small | 0.0388 | 0.2940 | / | start lr from 1e-4, bs 64| 
| 1810903  | norm | 0.0232 | 0.1345 | / | 0.1 const gso ratio (bad?) |
| 1811116  | norm | 0.0141 | 0.0768 | / | 0.3 const gso ratio |
| 1811117  | big | 0.0024 | 0.0601 | / | 0.1 const gso ratio, lr dec 3e-4 | 
| 1811142  | big | 0.0019 | 0.0535 | / | 0.1 const gso ratio, lr dec 3e-4 |
| 1811143  | big | 0.0020 | 0.0609 | / | exact same as 1811142 |
| 1811302 | big | 0.0063 | 0.0500 | / | gso .3 -> .1, lr dec 3e-4 |
| 1811304 | big | 0.0013 | 0.0310 | 1811142 | gso .3 -> .1, lr dec 3e-4, first time train loss < eval |
| 1814866 | big | 0.0012 | 0.0304 | 1811142 | gso .3 -> 0.0 |
| 1814867 | big | 0.0012 | **0.0277** | 1811142 | gso .3 -> 0.0, repeat 1814866 |
| 1817256 | big | 0.0022 | 0.0512 | 1811142 | batch size 256 (bad) |
| 1817333 | big | **0.0009** | 0.0333 | 1811142 | same as 1814866, but keep 0.0 from ep 500-1000 |
| 1823444 | big | 0.0014 | 0.0579 | / | same as 1817333, but from scratch |
|||||||



In [None]:
# plot schedulers: lr + gso loss ratio (one panel each, x axis = epoch)
fig, ax = plt.subplots(1, 2, figsize=(12, 3))
_panels = (
    ("Learning Rate [Log]", LEARNING_RATE, "Learning Rate"),
    ("GSO Loss Ratio", GSO_LOSS_RATIO, "GSO Loss Ratio"),
)
for _axis, (_title, _schedule, _ylabel) in zip(ax, _panels):
    _axis.set_title(_title)
    _axis.plot(_schedule, color="red")
    _axis.set_xlabel("Epoch")
    _axis.set_ylabel(_ylabel)
ax[0].set_yscale("log")  # only the learning-rate panel is logarithmic
plt.tight_layout()
if HAS_SCREEN:
    plt.show()
else:
    plt.savefig(f"mg_data/{JOBID}/schedulers.png")

| Measurement    | Mean        | Standard Deviation |
|----------------|-------------|--------------------|
| Current        | -10183.76   | 34209.11           |
| Magnetic       | -0.20       | 0.58               |
| F Profile      | 33.13       | 0.28               |
| P Profile      | 9654.42     | 8788.29            |

In [5]:
class PlaNetDataset(Dataset):
    """Dataset read from a MATLAB ``.mat`` file and held entirely on the global ``device``.

    Indexing returns the list ``[currs, mags, profs, psi, rr, zz]`` for one
    sample: normalized currents, magnetic measurements and profiles (the
    network inputs), the flux map ``psi`` (the target) and the two pixel
    coordinate grids ``rr`` / ``zz``.
    """
    def __init__(self, ds_mat_path):
        mat = sio.loadmat(ds_mat_path)
        # maps are stored as single-channel 64x64 images
        psi = to_tensor(mat["psi"]).view(-1, 1, 64, 64) # output: magnetic flux
        rr = to_tensor(mat["rr"]).view(-1, 1, 64, 64) # radial position of pixels
        zz = to_tensor(mat["zz"]).view(-1, 1, 64, 64) # vertical position of pixels
        raw_currs, raw_mags = mat["currs"], mat["magnetic"] # input currents / magnetic measurements
        raw_f, raw_p = mat["f_profiles"], mat["p_profiles"] # input profiles
        # normalization with fixed offset/scale constants
        currs = (to_tensor(raw_currs)+10183)/34209
        mags = (to_tensor(raw_mags)+0.2)/0.58
        f_norm = (to_tensor(raw_f)-33.13)/0.28
        p_norm = (to_tensor(raw_p)-9654)/8788
        profs = torch.cat((f_norm, p_norm), 1) # f and p profiles side by side
        # move to device (doable bc the dataset is fairly small, check memory usage)
        self.currs, self.mags, self.profs = currs.to(device), mags.to(device), profs.to(device)
        self.psi, self.rr, self.zz = psi.to(device), rr.to(device), zz.to(device)
        # the order of this list defines the order of the items in every sample
        self.everything = [self.currs, self.mags, self.profs, self.psi, self.rr, self.zz]
        n_bytes = sum([t.element_size()*t.nelement() for t in self.everything])
        print(f"Dataset: {len(self)}, memory: {n_bytes/1024**2:.0f} MB")

    def __len__(self):
        """Number of samples (leading dimension of ``psi``)."""
        return len(self.psi)

    def __getitem__(self, idx):
        """Return ``[currs, mags, profs, psi, rr, zz]`` sliced at ``idx``."""
        return [tensor[idx] for tensor in self.everything]

In [None]:
# test dataset
ds = PlaNetDataset(EVAL_DS_PATH)
print(f"Dataset length: {len(ds)}")
# a sample is [currs, mags, profs, psi, rr, zz]: the first three are inputs, psi (index 3) is the output
currs0, mags0, profs0, psi0 = ds[0][:4]
print(f"Input shapes: currs {currs0.shape}, mags {mags0.shape}, profs {profs0.shape}")
print(f"Output shape: {psi0.shape}")
n_plot = 10
fig, axs = plt.subplots(1, n_plot, figsize=(3*n_plot, 5))
for i, j in enumerate(np.random.randint(0, len(ds), n_plot)):
    # fetch the sample once, then bring the flux map and the two grids to numpy as 2D arrays
    psi, rr, zz = (t.cpu().numpy().squeeze() for t in ds[j][3:6])
    axs[i].contourf(rr, zz, psi, 100, cmap="inferno")
    axs[i].contour(rr, zz, -psi, 20, colors="black", linestyles="dotted")
    axs[i].axis("off")
    axs[i].set_aspect("equal")
plt.show() if HAS_SCREEN else plt.savefig(f"mg_data/{JOBID}/dataset.png")

### Network architecture
![|100](mg_data/planet_eq_net.jpg)

In [7]:
# MODEL: PlaNet: # Paper net: branch + trunk connection and everything
from torch.nn import Module, Linear, Conv2d, MaxPool2d, BatchNorm2d, ReLU, Sequential, ConvTranspose2d
Λ = ReLU() # activation function (a single instance reused in every layer stack)
class Head(Module):
    """Decoder: turn a 64-feature vector into a single-channel 64x64 map.

    A linear layer expands the vector to 4096 values, which are reshaped to
    (64, 8, 8) and upsampled by alternating transposed convolutions (x2) and
    unpadded 3x3 convolutions (-2 each); the final unpadded 5x5 convolution
    brings the spatial size to exactly 64.
    """
    def __init__(self):
        super().__init__()
        self.fc = Sequential(Linear(64, 4096), Λ)
        self.anti_conv = Sequential( # U-Net style
            ConvTranspose2d(64, 64, kernel_size=2, stride=2), # 8 -> 16
            Conv2d(64, 64, kernel_size=3, padding=0), Λ, # 16 -> 14
            Conv2d(64, 64, kernel_size=3, padding=0), Λ, # 14 -> 12
            ConvTranspose2d(64, 32, kernel_size=2, stride=2), # 12 -> 24
            Conv2d(32, 32, kernel_size=3, padding=0), Λ, # 24 -> 22
            Conv2d(32, 32, kernel_size=3, padding=0), Λ, # 22 -> 20
            ConvTranspose2d(32, 16, kernel_size=2, stride=2), # 20 -> 40
            Conv2d(16, 16, kernel_size=3, padding=0), Λ, # 40 -> 38
            Conv2d(16, 16, kernel_size=3, padding=0), Λ, # 38 -> 36
            ConvTranspose2d(16, 8, kernel_size=2, stride=2), # 36 -> 72
            Conv2d(8, 4, kernel_size=3, padding=0), Λ, # 72 -> 70
            Conv2d(4, 2, kernel_size=3, padding=0), Λ, # 70 -> 68
            Conv2d(2, 1, kernel_size=5, padding=0), # 68 -> 64, no activation on the output
        )
    def forward(self, x):
        """Map x of shape (n, 64) to a tensor of shape (n, 1, 64, 64)."""
        x = self.fc(x)
        x = x.view(-1, 64, 8, 8)
        x = self.anti_conv(x)
        return x
    def save(self, dir, prename):
        """Write the weights to ``{dir}/{prename}_head.pth``."""
        torch.save(self.state_dict(), f"{dir}/{prename}_head.pth")
    def load(self, dir, prename):
        """Read the weights from ``{dir}/{prename}_head.pth``.

        The file is deserialized onto the CPU so that a checkpoint written on
        a CUDA node can also be read on a machine without CUDA;
        ``load_state_dict`` then copies the values into this module's own
        parameters, whatever device they are on.
        """
        self.load_state_dict(torch.load(f"{dir}/{prename}_head.pth", map_location="cpu"))
    
##########################################################################################
class Trunk(Module):
    """Encode the two single-channel coordinate grids (r, z) into 64 features.

    Each grid goes through its own convolutional stack (three conv +
    batch-norm + activation + 2x max-pool stages, ending with 32 channels).
    The two results are flattened as 32*8*8 values each -- i.e. 64x64 input
    grids are expected -- concatenated and reduced by fully connected layers.
    """
    def __init__(self):
        super().__init__()
        def trunk_block():
            # padding=1 keeps the spatial size; every MaxPool2d(2) halves it
            return  Sequential(
                Conv2d(1, 8, kernel_size=3, stride=1, padding=1), BatchNorm2d(8), Λ, MaxPool2d(2),
                Conv2d(8, 16, kernel_size=3, stride=1, padding=1), BatchNorm2d(16), Λ, MaxPool2d(2),
                Conv2d(16, 32, kernel_size=3, stride=1, padding=1), BatchNorm2d(32), Λ, MaxPool2d(2),
            )
        # separate weights for the radial and the vertical grid
        self.trunk_r, self.trunk_z = trunk_block(), trunk_block()
        self.trunk_fc = Sequential(
            Linear(2*32*8*8, 128), Λ,
            Linear(128, 64), Λ,
            Linear(64, 64), Λ,
        )
    def forward(self, x):
        """Take the pair ``(r, z)`` and return a tensor with 64 features per sample."""
        r, z = x # split inputs
        r, z = self.trunk_r(r), self.trunk_z(z) # convolutions
        r, z = r.view(-1, 32*8*8), z.view(-1, 32*8*8) # flatten
        xt = torch.cat((r, z), 1) # concatenate
        xt = self.trunk_fc(xt) # fully connected
        return xt
    def save(self, dir, prename):
        """Write the weights to ``{dir}/{prename}_trunk.pth``."""
        torch.save(self.state_dict(), f"{dir}/{prename}_trunk.pth")
    def load(self, dir, prename):
        """Read the weights from ``{dir}/{prename}_trunk.pth``.

        Deserializing onto the CPU lets a checkpoint written on a CUDA node be
        read on a machine without CUDA; ``load_state_dict`` copies the values
        into this module's parameters on whatever device they are.
        """
        self.load_state_dict(torch.load(f"{dir}/{prename}_trunk.pth", map_location="cpu"))
    
##########################################################################################
class Branch(Module):
    """Fully connected encoder: ``input_size`` scalar inputs -> 64 features."""
    def __init__(self, input_size):
        super().__init__()
        self.branch = Sequential(
            Linear(input_size, 256), Λ,
            Linear(256, 128), Λ,
            Linear(128, 64), Λ
        )
    def forward(self, xb):
        """Map xb of shape (n, input_size) to shape (n, 64)."""
        return self.branch(xb)
    def save(self, dir, prename):
        """Write the weights to ``{dir}/{prename}_branch.pth``."""
        torch.save(self.state_dict(), f"{dir}/{prename}_branch.pth")
    def load(self, dir, prename):
        """Read the weights from ``{dir}/{prename}_branch.pth``.

        Deserializing onto the CPU lets a checkpoint written on a CUDA node be
        read on a machine without CUDA; ``load_state_dict`` copies the values
        into this module's parameters on whatever device they are.
        """
        self.load_state_dict(torch.load(f"{dir}/{prename}_branch.pth", map_location="cpu"))

##########################################################################################
class PlaNet(Module):
    """Full network: the branch and trunk features are multiplied element-wise
    and the product is decoded by the head.

    The three sub-modules are passed in already built, so several PlaNet
    instances can hold the very same trunk/head objects.
    """
    def __init__(self, branch:Branch, trunk:Trunk, head:Head):
        super().__init__()
        type_ok = isinstance(branch, Branch) and isinstance(trunk, Trunk) and isinstance(head, Head)
        assert type_ok
        self.branch = branch
        self.trunk = trunk
        self.head = head

    def forward(self, x):
        """Take ``(xb, r, z)`` and return the head output for ``branch(xb) * trunk((r, z))``."""
        xb, r, z = x
        branch_feat = self.branch(xb)
        trunk_feat = self.trunk((r, z))
        return self.head(branch_feat * trunk_feat)

    def save(self, dir, prename):
        """Save branch, trunk and head, each through its own ``save``."""
        for part in (self.branch, self.trunk, self.head):
            part.save(dir, prename)

    def load(self, dir, prename):
        """Load branch, trunk and head, each through its own ``load``."""
        for part in (self.branch, self.trunk, self.head):
            part.load(dir, prename)
        print(f"Model loaded from {dir}/{prename}")

In [None]:
# test input/output shapes: one sample first, then a small batch
x = (
    torch.rand(1, INPUT_SIZE),
    torch.rand(1, 1, 64, 64),
    torch.rand(1, 1, 64, 64),
)
net = PlaNet(Branch(INPUT_SIZE), Trunk(), Head())
y = net(x)
_in_shapes = [t.shape for t in x]
print(f"in: {_in_shapes}, out: {y.shape}")
n_sampl = 7
nx = (
    torch.rand(n_sampl, INPUT_SIZE),
    torch.rand(n_sampl, 1, 64, 64),
    torch.rand(n_sampl, 1, 64, 64),
)
ny = net(nx)
_in_shapes = [t.shape for t in nx]
print(f"in: {_in_shapes}, out: {ny.shape}")
# the network must give back one 64x64 map per sample
assert ny.shape == (n_sampl, 1, 64, 64), f"Wrong output shape: {ny.shape}"

## Training

In [None]:
# training
def _mean_columns(rows):
    """Column-wise mean of a list of equally long tuples (one tuple per batch)."""
    return tuple(sum(col)/len(col) for col in zip(*rows))

def _planet_batch_losses(models, loss_fn, batch):
    """Return one ``(mse, gso)`` pair of loss tensors per model for a single batch.

    ``batch`` is a dataset batch ``(curr, mag, prof, psi, rr, zz)``. Model 1 is
    fed currents + magnetics + profiles, model 2 currents + magnetics, model 3
    currents only. The gso loss compares ``calc_gso_batch`` of the prediction
    with ``calc_gso_batch`` of the dataset psi.
    """
    curr, mag, prof, psi, rr, zz = batch
    inputs = (torch.cat((curr, mag, prof), 1), torch.cat((curr, mag), 1), curr) # concatenate inputs
    gso = calc_gso_batch(psi, rr, zz, dev=device) # calculate grad shafranov
    pairs = []
    for model, xb in zip(models, inputs):
        psi_pred = model((xb, rr, zz)) # forward pass
        gso_pred = calc_gso_batch(psi_pred, rr, zz, dev=device)
        pairs.append((loss_fn(psi_pred, psi), loss_fn(gso_pred, gso)))
    return pairs

def train():
    """Jointly train three PlaNet models that share one trunk and one head.

    The three models differ only in their branch (input set, see
    ``_planet_batch_losses``). Every step minimises
    ``(1-r)*mse + r*gso`` where mse/gso are averaged over the three models and
    ``r = GSO_LOSS_RATIO[epoch]``; the learning rate is ``LEARNING_RATE[epoch]``.
    Whenever a model reaches its best eval mse (or gso) so far it is saved to
    ``SAVE_DIR`` under ``{PRENAME_MSE}_{k}`` (or ``{PRENAME_GSO}_{k}``).

    Returns:
        ``(success, models, logs)``.
        On success: ``True``, ``(model1, model2, model3)`` and the list of 12
        per-epoch loss histories ordered train mse1, gso1, mse2, gso2, mse3,
        gso3 followed by the same six for eval.
        If, from epoch index 10 on, the model-averaged eval loss exceeds the
        divergence thresholds (gso > 30.0 or mse > 11.0): ``False``,
        ``(None, None, None)`` and ``()`` -- the placeholder keeps three-way
        unpacking of the models valid for callers.

    Raises:
        NotImplementedError: if ``LOAD_PRETRAINED`` is set.
    """
    train_ds, val_ds = PlaNetDataset(TRAIN_DS_PATH), PlaNetDataset(EVAL_DS_PATH) # initialize datasets
    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True) # initialize DataLoader
    val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
    trunk, head = Trunk(), Head() # initialize modules (common modules)
    branches = (Branch(NCURRS+NMAGS+NPROFS), Branch(NCURRS+NMAGS), Branch(NCURRS)) # 3 different branches
    models = tuple(PlaNet(branch, trunk, head) for branch in branches)
    n_models = len(models)
    if LOAD_PRETRAINED is not None: raise NotImplementedError("Pretrained model loading is not implemented yet")
    for model in models: model.to(device) # move to device
    # every model lists the shared trunk/head parameters again; hand each tensor to
    # the optimizer once so shared weights get one update per step, like the branches
    unique_params = list(dict.fromkeys(p for model in models for p in model.parameters()))
    optimizer = torch.optim.Adam(unique_params, lr=LEARNING_RATE[0]) # optimizer
    loss_fn = torch.nn.MSELoss() # Mean Squared Error Loss
    # per-model histories, index k <-> model k+1
    tlog_mse, tlog_gso = [[] for _ in models], [[] for _ in models]
    elog_mse, elog_gso = [[] for _ in models], [[] for _ in models]
    start_time = time() # start time
    for ep in range(EPOCHS):
        epoch_time = time()
        lr, ratio = LEARNING_RATE[ep], GSO_LOSS_RATIO[ep]
        for pg in optimizer.param_groups: pg['lr'] = lr # update learning rate
        for model in models: model.train() # training mode (all three models)
        train_rows = [[] for _ in models] # per model: (mse, gso) of every batch
        for batch in train_dl:
            optimizer.zero_grad() # zero gradients
            pairs = _planet_batch_losses(models, loss_fn, batch)
            mse_loss = sum(mse for mse, _ in pairs)/n_models # average losses
            gso_loss = sum(gso for _, gso in pairs)/n_models
            loss = (1-ratio)*mse_loss + ratio*gso_loss # total loss
            loss.backward() # backprop
            optimizer.step() # update weights
            for rows, (mse, gso) in zip(train_rows, pairs): rows.append((mse.item(), gso.item()))
        for model in models: model.eval() # evaluation mode
        eval_rows, eval_avg_rows = [[] for _ in models], []
        with torch.no_grad():
            for batch in val_dl:
                pairs = _planet_batch_losses(models, loss_fn, batch)
                mse_loss = sum(mse for mse, _ in pairs)/n_models
                gso_loss = sum(gso for _, gso in pairs)/n_models
                eval_avg_rows.append((mse_loss.item(), gso_loss.item()))
                for rows, (mse, gso) in zip(eval_rows, pairs): rows.append((mse.item(), gso.item()))

        # epoch means over batches
        train_means = [_mean_columns(rows) for rows in train_rows]
        eval_means = [_mean_columns(rows) for rows in eval_rows]
        eloss_mse, eloss_gso = _mean_columns(eval_avg_rows)

        # epochs still to run after this one, scaled by the mean time per finished epoch
        eta_min = (time()-start_time)*(EPOCHS-ep-1)/(ep+1)/60
        print(f"{ep+1}/{EPOCHS}")
        for k, model in enumerate(models):
            (tmse, tgso), (emse, egso) = train_means[k], eval_means[k]
            # save model if improved; the tags at the end of the line tell which checkpoints were written
            tags = ""
            if emse <= min(elog_mse[k], default=emse):
                model.save(SAVE_DIR, f"{PRENAME_MSE}_{k+1}"); tags += f" *mse{k+1}"
            if egso <= min(elog_gso[k], default=egso):
                model.save(SAVE_DIR, f"{PRENAME_GSO}_{k+1}"); tags += f" *gso{k+1}"
            tlog_mse[k].append(tmse); tlog_gso[k].append(tgso)
            elog_mse[k].append(emse); elog_gso[k].append(egso)
            print(f"{k+1}: Eval: mse {emse:.4f}, gso {egso:.4f} | lr:{lr:.1e}, r:{ratio:.2f} | " +
                f"{time()-epoch_time:.0f}s, eta:{eta_min:.0f}m |", end=tags+"\n", flush=True)
        # stop training, if not converging, try again
        if ep >= 10 and (eloss_gso > 30.0 or eloss_mse > 11.0): return False, (None, None, None), ()
    print(f"Training time: {(time()-start_time)/60:.0f}mins")
    for k in range(n_models):
        print(f"Best losses {k+1}: mse{k+1} {min(elog_mse[k]):.4f}, gso{k+1} {min(elog_gso[k]):.4f}")
    for split, mse_logs, gso_logs in (("train", tlog_mse, tlog_gso), ("eval", elog_mse, elog_gso)):
        for k in range(n_models):
            np.save(f"{SAVE_DIR}/{split}_mse{k+1}_losses.npy", mse_logs[k]) # save losses
            np.save(f"{SAVE_DIR}/{split}_gso{k+1}_losses.npy", gso_logs[k])
    # flatten to: train mse1, gso1, mse2, gso2, mse3, gso3, then the same for eval
    logs = [log for mse_logs, gso_logs in ((tlog_mse, tlog_gso), (elog_mse, elog_gso))
            for pair in zip(mse_logs, gso_logs) for log in pair]
    return True, models, logs

# train the model (multiple attempts)
for i in range(10):
    # keep the result packed until success is known: a failed attempt carries no models/logs to unpack
    success, models, logs = train()
    if success:
        model1, model2, model3 = models
        (tlog_mse1, tlog_gso1, tlog_mse2, tlog_gso2, tlog_mse3, tlog_gso3,
         elog_mse1, elog_gso1, elog_mse2, elog_gso2, elog_mse3, elog_gso3) = logs
        break
    print(f"Convergence failed, retrying... {i+1}/10")
assert success, "Training failed"

In [None]:
# plot losses: one figure with linear y axis, then the same figure with log y axis
all_mse = [tlog_mse1, elog_mse1, tlog_mse2, elog_mse2, tlog_mse3, elog_mse3]
all_gso = [tlog_gso1, elog_gso1, tlog_gso2, elog_gso2, tlog_gso3, elog_gso3]
ce, ct = "yellow", "red"
lw = 1.0
for _log_scale in (False, True):
    fig, ax = plt.subplots(3, 2, figsize=(16, 10))
    _suffix = " (log)" if _log_scale else ""
    # row = model, left column = mse, right column = gso; curves are stored as [train, eval] per model
    for _row in range(3):
        for _col, (_kind, _curves) in enumerate((("MSE", all_mse), ("GSO", all_gso))):
            _panel = ax[_row, _col]
            _panel.set_title(f"{_kind} Loss {_row+1}{_suffix}")
            _panel.plot(_curves[2*_row], color=ct, label=f"train {_row+1}", linewidth=lw)
            _panel.plot(_curves[2*_row+1], color=ce, label=f"eval {_row+1}", linewidth=lw)
    for a in ax.flatten():
        a.legend()
        a.set_xlabel("Epoch")
        if _log_scale:
            a.set_ylabel("Loss [log]")
            a.grid(True, which="both", axis="y")
            a.set_yscale("log")
        else:
            a.set_ylabel("Loss")
    # same y range in the whole column, so the three models can be compared directly
    for a in ax[:,0]: a.set_ylim(min([min(x) for x in all_mse]), max([max(x) for x in all_mse]))
    for a in ax[:,1]: a.set_ylim(min([min(x) for x in all_gso]), max([max(x) for x in all_gso]))
    plt.tight_layout()
    if HAS_SCREEN:
        plt.show()
    elif _log_scale:
        plt.savefig(f"mg_data/{JOBID}/losses_log.png")
    else:
        plt.savefig(f"mg_data/{JOBID}/losses.png")


In [None]:
# testing network output
def _sync_device(dev):
    """Wait for queued work on `dev` to finish, so that time() brackets the real computation."""
    if dev.type == "cuda":
        torch.cuda.synchronize()
    elif dev.type == "mps":
        mps_sync = getattr(getattr(torch, "mps", None), "synchronize", None) # not present in older torch
        if mps_sync is not None: mps_sync()

cpu_times1, cpu_times2, dev_times1, dev_times2 = [], [], [], []
# trunk and head must be the very same objects in all three models
assert model1.trunk is model2.trunk is model3.trunk
assert model1.head is model2.head is model3.head
ds = PlaNetDataset(EVAL_DS_PATH) # loaded once, reused for every checkpoint
# ds = PlaNetDataset(TRAIN_DS_PATH)
os.makedirs(f"mg_data/{JOBID}/imgs", exist_ok=True)
N_PLOTS = 2 if HAS_SCREEN else 50
for tit, prename in zip(["MSE1", "GSO1", "MSE2", "GSO2"], [f"{PRENAME_MSE}_1", f"{PRENAME_GSO}_1", f"{PRENAME_MSE}_2", f"{PRENAME_GSO}_2"]):
    is_model_1 = "1" in tit
    branch = model1.branch if is_model_1 else model2.branch
    trunk, head = model1.trunk, model1.head
    # overwrite the weights with the checkpoint under test
    trunk.load(SAVE_DIR, prename), head.load(SAVE_DIR, prename), branch.load(SAVE_DIR, prename)
    model = PlaNet(branch, trunk, head)
    model.to("cpu")
    model.eval()
    for i in np.random.randint(0, len(ds), N_PLOTS):
        fig, axs = plt.subplots(2, 5, figsize=(15, 9))
        curr, mag, prof, psi_ds, rr, zz = [d.to("cpu") for d in ds[i]]
        curr, mag, prof, psi_ds, rr, zz = curr.view(1,-1), mag.view(1,-1), prof.view(1,-1), psi_ds.view(1,1,64,64), rr.view(1,1,64,64), zz.view(1,1,64,64)
        net_in = torch.cat((curr, mag, prof), 1) if is_model_1 else torch.cat((curr, mag), 1)
        # dev: time one forward pass on the training device
        model.to(device)
        net_in, rr, zz = net_in.to(device), rr.to(device), zz.to(device)
        with torch.no_grad(): # inference only, no autograd bookkeeping inside the timed region
            _sync_device(device)
            start = time()
            psi_pred = model((net_in, rr, zz))
            _sync_device(device)
            end = time()
        if is_model_1: dev_times1.append(end-start)
        else: dev_times2.append(end-start)
        # cpu: same forward pass on the cpu; this prediction is the one plotted below
        model.to("cpu")
        net_in, rr, zz = net_in.to("cpu"), rr.to("cpu"), zz.to("cpu")
        with torch.no_grad():
            start = time()
            psi_pred = model((net_in, rr, zz))
            end = time()
        if is_model_1: cpu_times1.append(end-start)
        else: cpu_times2.append(end-start)

        gso, gso_pred = calc_gso_batch(psi_ds, rr, zz), calc_gso_batch(psi_pred, rr, zz)
        gso, gso_pred = gso.detach().numpy().reshape(64, 64), gso_pred.detach().numpy().reshape(64, 64)
        gso_range = (gso.max(), gso.min()) # (max, min) of the reference gso
        gso_levels = np.linspace(gso_range[1], gso_range[0], 12)
        gso_pred = np.clip(gso_pred, gso_range[1], gso_range[0]) # clip to gso range

        psi_pred = psi_pred.detach().numpy().reshape(64, 64)
        psi_ds = psi_ds.detach().numpy().reshape(64, 64)
        rr, zz = rr.view(64, 64).detach().numpy(), zz.view(64, 64).detach().numpy()
        bmin, bmax = np.min([psi_ds, psi_pred]), np.max([psi_ds, psi_pred]) # min max psi
        blevels = np.linspace(bmin, bmax, 13, endpoint=True) # common levels for actual and predicted
        ψ_mse = (psi_ds - psi_pred)**2
        gso_mse = (gso - gso_pred)**2
        mse_levels1 = np.linspace(0, 0.5, 13, endpoint=True)
        mse_levels2 = np.linspace(0, 0.05, 13, endpoint=True)

        # top row: psi (actual, predicted, both as contours, squared error at two scales)
        im00 = axs[0,0].contourf(rr, zz, psi_ds, blevels, cmap="inferno")
        axs[0,0].set_title("Actual")
        axs[0,0].set_ylabel("ψ")
        fig.colorbar(im00, ax=axs[0,0])
        im01 = axs[0,1].contourf(rr, zz, psi_pred, blevels, cmap="inferno")
        axs[0,1].set_title("Predicted")
        fig.colorbar(im01, ax=axs[0,1])
        im02 = axs[0,2].contour(rr, zz, psi_ds, blevels, linestyles='dashed', cmap="inferno")
        axs[0,2].contour(rr, zz, psi_pred, blevels, cmap="inferno")
        axs[0,2].set_title("Contours")
        fig.colorbar(im02, ax=axs[0,2])
        im03 = axs[0,3].contourf(rr, zz, np.clip(ψ_mse, 0, 0.5), mse_levels1, cmap="inferno")
        axs[0,3].set_title("MSE 0.5")
        fig.colorbar(im03, ax=axs[0,3])
        im04 = axs[0,4].contourf(rr, zz, np.clip(ψ_mse, 0.00001, 0.04999), mse_levels2, cmap="inferno")
        axs[0,4].set_title("MSE 0.05")
        fig.colorbar(im04, ax=axs[0,4])
        # bottom row: the same five panels for the gso
        im10 = axs[1,0].contourf(rr, zz, gso, gso_levels, cmap="inferno")
        axs[1,0].set_ylabel("GSO")
        fig.colorbar(im10, ax=axs[1,0])
        im6 = axs[1,1].contourf(rr, zz, gso_pred, gso_levels, cmap="inferno")
        fig.colorbar(im6, ax=axs[1,1])
        im12 = axs[1,2].contour(rr, zz, gso, gso_levels, linestyles='dashed', cmap="inferno")
        axs[1,2].contour(rr, zz, gso_pred, gso_levels, cmap="inferno")
        fig.colorbar(im12, ax=axs[1,2])
        im13 = axs[1,3].contourf(rr, zz, np.clip(gso_mse, 0, 0.5), mse_levels1, cmap="inferno")
        fig.colorbar(im13, ax=axs[1,3])
        im14 = axs[1,4].contourf(rr, zz, np.clip(gso_mse, 0.00001, 0.04999), mse_levels2, cmap="inferno")
        fig.colorbar(im14, ax=axs[1,4])
        for ax in axs.flatten(): ax.grid(False), ax.set_xticks([]), ax.set_yticks([]), ax.set_aspect("equal")
        plt.suptitle(f"PlaNet: {tit} {i}")
        plt.tight_layout()
        plt.show() if HAS_SCREEN else plt.savefig(f"mg_data/{JOBID}/imgs/planet_{tit}_{i}.png")
        plt.close()
cpu_times1, cpu_times2, dev_times1, dev_times2 = map(np.array, [cpu_times1, cpu_times2, dev_times1, dev_times2])
print(f"cpu: inference time 1: {cpu_times1.mean():.5f}s, std: {cpu_times1.std():.5f}")
print(f"cpu: inference time 2: {cpu_times2.mean():.5f}s, std: {cpu_times2.std():.5f}")
print(f"dev: inference time 1: {dev_times1.mean():.5f}s, std: {dev_times1.std():.5f}")
print(f"dev: inference time 2: {dev_times2.mean():.5f}s, std: {dev_times2.std():.5f}")


In [None]:
print("Done", flush=True)
# wait for files to update (for cluster); skipped when running with a screen
if not HAS_SCREEN:
    sleep(30)

In [None]:
# copy the job's log file into the run folder
_log_src, _log_dst = f"jobs/{JOBID}.txt", f"mg_data/{JOBID}/log.txt"
os.system(f"cp {_log_src} {_log_dst}")