## Training
Prepare the dataset with the `prepare_dataset` notebook before running this one.

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn import Module, Linear, Conv2d, MaxPool2d, BatchNorm2d, ReLU, Sequential, ConvTranspose2d
from tqdm import tqdm
from time import time, sleep
import numpy as np
import os
import warnings; warnings.filterwarnings("ignore")
from utils import *
import shutil

# Detect the execution environment: SLURM exports SLURM_JOB_ID when training on
# the cluster; if it is absent we are running locally.
JOBID = os.environ.get("SLURM_JOB_ID")
if JOBID is not None:  # cluster: NVIDIA GPU if available, otherwise CPU
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    HAS_SCREEN = False  # no display: figures are saved to disk instead of shown
else:  # local: Apple silicon (MPS) if available, otherwise CPU
    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    JOBID = "local"
    HAS_SCREEN = True
SAVE_DIR = f"data/{JOBID}"
# device = torch.device('cpu') # for debugging, use cpu
os.makedirs(f"{SAVE_DIR}/models", exist_ok=True)  # also creates SAVE_DIR itself
print(f'device: {device}, has_screen: {HAS_SCREEN}, job id: {JOBID}')

# Copy the training sources next to the results (for the cluster). This is
# best-effort: a source file that does not exist (e.g. locally) is skipped.
for _src in ("train.py", "utils.py"):
    try:
        shutil.copy(_src, f"{SAVE_DIR}/{_src}")
    except FileNotFoundError:
        print(f"skipping copy of {_src}: file not found")

def to_tensor(x, device=torch.device("cpu")):
    """Copy array-like ``x`` into a new float32 tensor placed on ``device``."""
    tensor_kwargs = {"dtype": torch.float32, "device": device}
    return torch.tensor(x, **tensor_kwargs)

# SMALL, NORM, BIG = "small", "norm", "big"

In [None]:

EPOCHS = 100 # number of epochs # 1000
BATCH_SIZE = 128 # 128 <-

# MODEL_NAME = SMALL
# MODEL_NAME = NORM
# MODEL_NAME = BIG

LOAD_PRETRAINED = None # Set it to None if you don't want to load pretrained model
# LOAD_PRETRAINED = "trained_models/pretrained_1809761.pth" # norm model
# LOAD_PRETRAINED = "trained_models/pretrained_small_1810888.pth" # small model
# LOAD_PRETRAINED = "trained_models/pretrained_big_1811142.pth" # big model

# Learning-rate schedule: one value per epoch.
# LEARNING_RATE = 3e-4*np.logspace(0, -2, EPOCHS) 
# LEARNING_RATE = 3e-3*np.ones(EPOCHS) 
# LEARNING_RATE = 3e-3*np.logspace(0, -2, EPOCHS)  # <-
LEARNING_RATE = 1e-3*np.logspace(0, -2, EPOCHS)  # decays from 1e-3 to 1e-5

# Weight of the Grad-Shafranov (GSO) loss term, one value per epoch, in four phases:
# ramp up to MAX_GSO, hold at MAX_GSO, ramp down to 0, hold at 0.
# The last phase absorbs the remainder, so the schedule always has exactly EPOCHS
# entries even when EPOCHS is not divisible by 4.
MAX_GSO = 5e-3 # 3e-3 <-
_phase_len = EPOCHS // 4  # length of each of the first three phases
GSO_LOSS_RATIO = np.concatenate((np.linspace(1e-6, MAX_GSO, _phase_len),
                                 np.full(_phase_len, MAX_GSO),
                                 np.linspace(MAX_GSO, 0.0, _phase_len),
                                 np.zeros(EPOCHS - 3 * _phase_len)))
# GSO_LOSS_RATIO = np.concatenate((np.linspace(1e-6, 3e-3, EPOCHS//2), np.linspace(3e-3, 0.0, EPOCHS//2))) 
# GSO_LOSS_RATIO = np.concatenate((MAX_GSO*np.logspace(-6, 0, EPOCHS//4), 
#                                  MAX_GSO*np.logspace(0, 0, EPOCHS//4), 
#                                  MAX_GSO*np.logspace(0, -10, EPOCHS//4), 
#                                  np.logspace(-12, -12, EPOCHS//4))) 
# GSO_LOSS_RATIO = np.zeros(EPOCHS) 

TRAIN_DS_PATH = "dss/train_ds.npz" # generated by the prepare_dataset notebook
EVAL_DS_PATH = "dss/eval_ds.npz"

In [None]:
# Sanity checks on the configuration. Explicit raises (not `assert`) so they still
# run when Python is started with -O.
if LOAD_PRETRAINED is not None and not os.path.exists(LOAD_PRETRAINED):
    raise FileNotFoundError(f"Pretrained model does not exist: {LOAD_PRETRAINED}")
for _path, _what in ((TRAIN_DS_PATH, "Training dataset"),
                     (EVAL_DS_PATH, "Evaluation dataset"),
                     (SAVE_DIR, "Save directory")):
    if not os.path.exists(_path):
        raise FileNotFoundError(f"{_what} does not exist: {_path}")
if len(LEARNING_RATE) != EPOCHS:
    raise ValueError("Learning rate array length does not match epochs")
if len(GSO_LOSS_RATIO) != EPOCHS:
    raise ValueError("GSO loss ratio array length does not match epochs")

In [None]:
# Plot the per-epoch schedules side by side: learning rate (log scale) and GSO loss weight.
fig, ax = plt.subplots(1, 2, figsize=(12, 3))
for a, values, title, ylabel in ((ax[0], LEARNING_RATE, "Learning Rate [Log]", "Learning Rate"),
                                 (ax[1], GSO_LOSS_RATIO, "GSO Loss Ratio", "GSO Loss Ratio")):
    a.set_title(title)
    a.plot(values, color="red")
    a.set_xlabel("Epoch")
    a.set_ylabel(ylabel)
ax[0].set_yscale("log")  # the learning rate is log-spaced
plt.tight_layout()
if HAS_SCREEN:
    plt.show()
else:
    plt.savefig(f"{SAVE_DIR}/schedulers.png")

In [None]:
class PlaNetDataset(Dataset):
    """In-memory dataset read from an ``.npz`` file with arrays ``X``, ``Y``, ``r``, ``z``.

    ``X`` holds the network inputs. ``Y`` (target flux) and the pixel coordinates
    ``r``/``z`` are reshaped to ``(n, 1, NGZ, NGR)``. All four tensors are moved to
    the global ``device`` once at construction, so items need no further transfer.
    Each item is the tuple ``(X[idx], Y[idx], r[idx], z[idx])``.
    """
    def __init__(self, ds_mat_path):
        # Use `with` so the .npz archive is closed once the arrays are read;
        # a bare np.load(...) keeps the file handle open.
        with np.load(ds_mat_path) as d:
            self.X = to_tensor(d["X"])  # (n, NIN) # inputs: currents + measurements + profiles
            # output: magnetic flux, transposed (matlab is column-major)
            self.Y = to_tensor(d["Y"]).view(-1, 1, NGZ, NGR)
            self.r = to_tensor(d["r"]).view(-1, 1, NGZ, NGR)  # radial position of pixels
            self.z = to_tensor(d["z"]).view(-1, 1, NGZ, NGR)  # vertical position of pixels
        # move to device (doable bc the dataset is fairly small, check memory usage)
        self.Y, self.X, self.r, self.z = self.Y.to(device), self.X.to(device), self.r.to(device), self.z.to(device)
        total_memory = sum(t.element_size() * t.nelement() for t in (self.Y, self.X, self.r, self.z))
        print(f"Dataset: {len(self)}, memory: {total_memory/1024**2:.0f} MB")

    def __len__(self):
        return len(self.Y)

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx], self.r[idx], self.z[idx]

In [None]:
# Inspect the evaluation set: print its size and sample shapes, then plot random samples.
ds = PlaNetDataset(EVAL_DS_PATH)
print(f"Dataset length: {len(ds)}")
print(f"Input shape: {ds[0][0].shape}")
print(f"Output shape: {ds[0][1].shape}")
n_plot = 10
print(len(ds))
idxs = np.random.randint(0, len(ds), n_plot)

# Target maps (Y) over the (r, z) grid, with the vessel outline VESS drawn in white.
fig, axs = plt.subplots(1, n_plot, figsize=(3*n_plot, 5))
for a, j in zip(axs, idxs):
    psi_map, r_map, z_map = (t.cpu().numpy().squeeze() for t in ds[j][1:])
    a.contourf(r_map, z_map, psi_map, 100, cmap="inferno")
    a.plot(VESS[:,0], VESS[:,1], color="white", linewidth=2)
    fig.colorbar(a.collections[0], ax=a)
    a.axis("off")
    a.set_aspect("equal")
if HAS_SCREEN:
    plt.show()
else:
    plt.savefig(f"{SAVE_DIR}/dataset.png")

# Input vectors of the same samples, split into the index ranges used below
# (0-18 currents, 19-56 magnetic, 57+ profiles), each shown only if enabled.
fig, axs = plt.subplots(1, n_plot, figsize=(3*n_plot, 5))
for a, j in zip(axs, idxs):
    inputs = ds[j][0].cpu().numpy().squeeze()
    if USE_CURRENTS: a.plot(inputs[:19], label="currents")
    if USE_MAGNETIC: a.plot(inputs[19:57], label="magnetic")
    if USE_PROFILES: a.plot(inputs[57:], label="profiles")
    a.legend()
    a.set_title(f"Sample {j}")
    a.set_xlabel("Input index")
if HAS_SCREEN:
    plt.show()
else:
    plt.savefig(f"{SAVE_DIR}/dataset_inputs.png")

In [None]:
# Build the training and validation datasets (each loads its whole .npz onto `device`).
train_ds, val_ds = PlaNetDataset(TRAIN_DS_PATH), PlaNetDataset(EVAL_DS_PATH)

In [None]:
# activation functions
# custom trainable swish
class Swish(Module):
    """Swish activation ``x * sigmoid(β * x)`` with a trainable slope ``β``.

    ``β`` is registered as an ``nn.Parameter``, so the inherited ``Module.to``
    (and ``.cuda()``, ``.double()``, ...) moves or casts it together with the rest
    of the model. No ``to`` override is needed. The former override assigned a
    plain Tensor to the Parameter slot, which raises TypeError on any
    device/dtype change.
    """
    def __init__(self, β=1.0):
        super().__init__()
        self.β = torch.nn.Parameter(torch.tensor(β), requires_grad=True)

    def forward(self, x):
        return x * torch.sigmoid(self.β * x)

# Λ = ReLU() # ReLU activation function
# Λ = Swish() # Swish activation function

# class Λ(Module): # relu
#     def __init__(self): super(Λ, self).__init__()
#     def forward(self, x): return torch.relu(x)

class Λ(Module): # swish
    """Swish activation ``x * sigmoid(β * x)`` whose slope ``β`` is learned, starting at 1.

    The attribute name ``β`` becomes the state_dict key of the parameter, so it is
    kept unchanged for compatibility with saved checkpoints.
    """
    def __init__(self):
        super().__init__()
        init_beta = torch.tensor(1.0)
        self.β = torch.nn.Parameter(init_beta, requires_grad=True)

    def forward(self, x):
        gate = torch.sigmoid(self.β * x)
        return x * gate

In [None]:
class EasyPlaNet(Module): # Paper net: branch + trunk connection and everything
    """Branch/trunk network that predicts a map on the ``grid_size`` (NGZ, NGR) grid.

    - ``branch``: MLP from the flat input vector (``input_size`` values) to
      ``latent_size`` features.
    - ``trunk_r`` / ``trunk_z``: MLPs from the flattened r and z coordinate grids
      (``NGZ*NGR`` values each) to ``latent_size // 2`` features each; their
      outputs are concatenated to ``latent_size`` features.
    - ``head``: maps the element-wise product of branch and trunk features back to
      a ``(batch, 1, *grid_size)`` map.

    NOTE: the attribute names and the layer order inside each Sequential define the
    state_dict keys written by torch.save(model.state_dict()); changing them breaks
    loading of previously saved checkpoints.
    NOTE(review): ``View`` comes from utils and is assumed to reshape like
    ``Tensor.view`` — confirm.
    """
    def __init__(self, input_size=NIN, latent_size=32, grid_size=(NGZ,NGR)):
        super(EasyPlaNet, self).__init__()
        # latent_size is split evenly between the r and z trunks
        assert latent_size % 2 == 0, "latent size should be even"
        self.input_size, self.latent_size, self.grid_size = input_size, latent_size, grid_size
        self.fgs = grid_size[0]*grid_size[1] # fgs: flattened grid size
        #branch: input vector -> latent_size features
        self.branch = Sequential(
            View(-1, input_size),
            Linear(input_size, 64), Λ(),
            Linear(64, 32), Λ(),
            Linear(32, latent_size), Λ(),
        )
        #trunk: whole flattened coordinate grid -> latent_size//2 features (one block each for r and z)
        def trunk_block(): 
            return  Sequential(
                View(-1, self.fgs),
                Linear(self.fgs, 32), Λ(),
                Linear(32, latent_size//2), Λ(),
            )
        self.trunk_r, self.trunk_z = trunk_block(), trunk_block()
        # head: latent features -> (batch, 1, *grid_size) map
        # NOTE(review): the head ends with a Λ (swish) activation, which bounds the
        # output from below; confirm this is intended for the target range.
        self.head = Sequential(
            Linear(latent_size, 64), Λ(),
            Linear(64, self.fgs), Λ(),
            View(-1, 1, *self.grid_size),
        )
    def forward(self, xb, r, z):
        """Predict the map for a batch.

        xb: inputs with ``input_size`` columns (checked by the assert below).
        r, z: coordinate grids, flattened by the trunks to ``NGZ*NGR`` values per
        sample (the smoke test passes ``(batch, 1, NGZ, NGR)``).
        Returns the head output reshaped to ``(batch, 1, *grid_size)``.
        """
        assert xb.shape[1] == self.input_size, f"branch input shape {xb.shape} != {self.input_size}"
        #branch net
        xb = self.branch(xb)
        assert xb.shape[1] == self.latent_size, f"branch output shape {xb.shape} != {self.latent_size}"
        #trunk net
        r, z = self.trunk_r(r), self.trunk_z(z) 
        xt = torch.cat((r, z), 1) # concatenate along features: latent_size//2 + latent_size//2
        assert xt.shape[1] == self.latent_size, f"trunk output shape {xt.shape} != {self.latent_size}"
        x = xt * xb # multiply trunk and branch (element-wise)
        x = self.head(x) # head net
        return x

In [None]:
# Smoke test: run the untrained model on one sample and on a small batch, checking output shapes.
_grid = (1, NGZ, NGR)
x, rr, zz = torch.rand(1, NIN), torch.rand(1, *_grid), torch.rand(1, *_grid)
net = EasyPlaNet()
y = net(x, rr, zz)
print(f"in: {x.shape}, {rr.shape}, {zz.shape}, \nout: {y.shape}")
n_sampl = 7
nx = torch.rand(n_sampl, NIN)
rr = torch.rand(n_sampl, *_grid)
zz = torch.rand(n_sampl, *_grid)
ny = net(nx, rr, zz)
print(f"in: {nx.shape}, {rr.shape}, {zz.shape}, \nout: {ny.shape}")
assert ny.shape == (n_sampl, *_grid), f"Wrong output shape: {ny.shape}"

## Training

In [None]:
def _batch_losses(model, loss_fn, gso_ratio, x_in, Y, rr, zz):
    """Return the (total, mse, gso) loss tensors for one batch.

    total = (1 - gso_ratio) * mse + gso_ratio * gso. Here mse compares the predicted
    and target maps, and gso compares their Grad-Shafranov maps from calc_gso_batch.
    """
    psi_pred = model(x_in, rr, zz)  # forward pass
    gso, gso_pred = calc_gso_batch(Y, rr, zz, dev=device), calc_gso_batch(psi_pred, rr, zz, dev=device)  # calculate grad shafranov
    mse_loss = loss_fn(psi_pred, Y)  # mean squared error loss on Y
    gso_loss = loss_fn(gso_pred, gso)  # PINN loss on grad shafranov
    loss = (1 - gso_ratio) * mse_loss + gso_ratio * gso_loss  # total loss
    return loss, mse_loss, gso_loss


def _mean_losses(batch_losses):
    """Average a list of per-batch (tot, mse, gso) float tuples component-wise."""
    return tuple(sum(col) / len(col) for col in zip(*batch_losses))


def train():
    """Train a fresh EasyPlaNet using the global schedules and datasets.

    Per epoch, the learning rate and GSO loss weight come from LEARNING_RATE and
    GSO_LOSS_RATIO. A checkpoint ``{SAVE_DIR}/mg_planet_{tot,mse,gso}.pth`` is saved
    whenever the corresponding mean eval loss is <= its best so far.

    Returns:
        (False, ()) if, after epoch 10, the eval losses indicate the run is not
        converging (the caller retries). Otherwise returns (True, logs), where logs
        is (tlog_tot, tlog_mse, tlog_gso, elog_tot, elog_mse, elog_gso) of
        per-epoch mean losses; these are also saved as .npy files in SAVE_DIR.
    """
    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
    model = EasyPlaNet()
    if LOAD_PRETRAINED is not None:  # optionally start from a saved state_dict
        model.load_state_dict(torch.load(LOAD_PRETRAINED, map_location=torch.device("cpu")))
        print(f"Pretrained model loaded: {LOAD_PRETRAINED}")
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE[0])
    loss_fn = torch.nn.MSELoss()  # Mean Squared Error Loss
    tlog_tot, tlog_mse, tlog_gso, elog_tot, elog_mse, elog_gso = [], [], [], [], [], []  # per-epoch losses
    start_time = time()
    for ep in range(EPOCHS):
        epoch_time = time()
        lr, gso_ratio = float(LEARNING_RATE[ep]), float(GSO_LOSS_RATIO[ep])
        for pg in optimizer.param_groups:
            pg['lr'] = lr  # apply this epoch's learning rate

        # training pass
        model.train()
        trainloss, evalloss = [], []
        for x_in, Y, rr, zz in train_dl:
            optimizer.zero_grad()
            loss, mse_loss, gso_loss = _batch_losses(model, loss_fn, gso_ratio, x_in, Y, rr, zz)
            loss.backward()
            optimizer.step()
            trainloss.append((loss.item(), mse_loss.item(), gso_loss.item()))

        # evaluation pass
        model.eval()
        with torch.no_grad():
            for x_in, Y, rr, zz in val_dl:
                losses = _batch_losses(model, loss_fn, gso_ratio, x_in, Y, rr, zz)
                evalloss.append(tuple(l.item() for l in losses))

        tloss_tot, tloss_mse, tloss_gso = _mean_losses(trainloss)
        eloss_tot, eloss_mse, eloss_gso = _mean_losses(evalloss)

        # Save a checkpoint for every metric that matched or improved on its best so far,
        # and tag the progress line with all of them (not just the last one).
        saved = []
        for tag, value, history in (("tot", eloss_tot, elog_tot),
                                    ("mse", eloss_mse, elog_mse),
                                    ("gso", eloss_gso, elog_gso)):
            if value <= min(history, default=value):
                torch.save(model.state_dict(), f"{SAVE_DIR}/mg_planet_{tag}.pth")
                saved.append(tag)
        endp = f" [{', '.join(saved)}]\n" if saved else "\n"

        tlog_tot.append(tloss_tot); tlog_mse.append(tloss_mse); tlog_gso.append(tloss_gso)
        elog_tot.append(eloss_tot); elog_mse.append(eloss_mse); elog_gso.append(eloss_gso)
        # remaining time: average epoch time so far x epochs still to run after this one
        eta_min = (time() - start_time) * (EPOCHS - ep - 1) / (ep + 1) / 60
        print(f"[{ep+1}/{EPOCHS}] "
              f"Eval -> tot {eloss_tot:.4f}, mse {eloss_mse:.4f}, gso {eloss_gso:.2e}, "
              f"lr {lr:.1e}, r {gso_ratio:.1e}, {time()-epoch_time:.0f}s, eta {eta_min:.0f}m",
              end=endp, flush=True)
        # stop training if not converging; the caller retries from scratch
        if ep >= 10 and ((eloss_gso > 30.0 and gso_ratio > 0.01) or eloss_mse > .2):
            return False, ()

    print(f"Training time: {(time()-start_time)/60:.0f}mins")
    print(f"Best losses: tot {min(elog_tot):.3e}, mse {min(elog_mse):.3e}, gso {min(elog_gso):.4f}")
    # the square root of a mean squared error is an RMSE, not an MAE
    print(f"Estimated RMSE: tot {np.sqrt(min(elog_tot)):.3e}, rmse {np.sqrt(min(elog_mse)):.3e}, gso {np.sqrt(min(elog_gso)):.4f}")
    for l, n in zip([tlog_tot, tlog_mse, tlog_gso], ["tot", "mse", "gso"]): np.save(f"{SAVE_DIR}/train_{n}_losses.npy", l)
    for l, n in zip([elog_tot, elog_mse, elog_gso], ["tot", "mse", "gso"]): np.save(f"{SAVE_DIR}/eval_{n}_losses.npy", l)
    return True, (tlog_tot, tlog_mse, tlog_gso, elog_tot, elog_mse, elog_gso)

# Train the model; a run that stops early for not converging is retried from scratch.
import shutil
MAX_ATTEMPTS = 20
for i in range(MAX_ATTEMPTS):
    success, logs = train()
    if success:
        tlog_tot, tlog_mse, tlog_gso, elog_tot, elog_mse, elog_gso = logs
        break
    print(f"Convergence failed, retrying... {i+1}/{MAX_ATTEMPTS}")
if not success:
    # No usable model: delete the run directory (no shell involved) and stop the
    # script with an exception that, unlike assert, is not stripped under -O.
    shutil.rmtree(SAVE_DIR, ignore_errors=True)
    raise RuntimeError("Training failed, no model saved")

In [None]:
# Plot train/eval loss curves: one column per loss term (TOT, MSE, GSO),
# top row on a linear scale, bottom row on a log scale.
fig, ax = plt.subplots(2, 3, figsize=(12, 6))
ce, ct = "yellow", "red"
lw = 1.0
_loss_curves = (("TOT", tlog_tot, elog_tot),
                ("MSE", tlog_mse, elog_mse),
                ("GSO", tlog_gso, elog_gso))
for row in (0, 1):
    for col, (loss_name, train_curve, eval_curve) in enumerate(_loss_curves):
        cell = ax[row, col]
        cell.set_title(f"{loss_name} Loss" + (" (log)" if row == 1 else ""))
        cell.plot(train_curve, color=ct, label="train", linewidth=lw)
        cell.plot(eval_curve, color=ce, label="eval", linewidth=lw)
        if row == 1:
            cell.set_yscale("log")
            cell.grid(True, which="both", axis="y")
plt.suptitle(f"[{JOBID}] Training losses")
for a in ax.flatten(): a.legend(); a.set_xlabel("Epoch"); a.set_ylabel("Loss")
plt.tight_layout()
if HAS_SCREEN:
    plt.show()
else:
    plt.savefig(f"{SAVE_DIR}/losses.png")

In [None]:
# testing network output
# For each best checkpoint (lowest eval tot / mse / gso loss), plot N_PLOTS random
# validation samples. Top row: flux ψ. Bottom row: GSO map. Columns: target,
# prediction, overlaid contours (dashed = target), absolute error at two scales.
# NOTE: the loop variable `best_model_path` keeps its last value after the loop and
# is read again by the inference-speed cell.
os.makedirs(f"{SAVE_DIR}/imgs", exist_ok=True)
N_PLOTS = 2 if HAS_SCREEN else 50
ds = val_ds
for titl, best_model_path in zip(["TOT","MSE", "GSO"], ["mg_planet_tot.pth", "mg_planet_mse.pth", "mg_planet_gso.pth"]):
    model = EasyPlaNet()
    # checkpoints may have been saved from a GPU/MPS model; this evaluation runs on CPU
    model.load_state_dict(torch.load(f"{SAVE_DIR}/{best_model_path}", map_location=torch.device("cpu")))
    model.eval()
    for i in np.random.randint(0, len(ds), N_PLOTS):
        fig, axs = plt.subplots(2, 5, figsize=(15, 9))
        x_in, psi_ds, rr, zz = (t.to('cpu') for t in ds[i])
        x_in, psi_ds, rr, zz = x_in.reshape(1,-1), psi_ds.reshape(1,1,NGZ,NGR), rr.reshape(1,1,NGZ,NGR), zz.reshape(1,1,NGZ,NGR)
        with torch.no_grad():  # inference only: no autograd graph needed
            psi_pred = model(x_in, rr, zz)
            gso, gso_pred = calc_gso_batch(psi_ds, rr, zz), calc_gso_batch(psi_pred, rr, zz)
        gso, gso_pred = gso.detach().numpy().reshape(NGZ,NGR), gso_pred.detach().numpy().reshape(NGZ,NGR)
        gso_min, gso_max = np.min([gso, gso_pred]), np.max([gso, gso_pred])
        gso_levels = np.linspace(gso_min, gso_max, 13, endpoint=True)  # shared levels for target and prediction

        psi_pred = psi_pred.detach().numpy().reshape(NGZ,NGR)
        psi_ds = psi_ds.detach().numpy().reshape(NGZ,NGR)
        rr, zz = rr.view(NGZ,NGR).detach().numpy(), zz.view(NGZ,NGR).detach().numpy()
        bmin, bmax = np.min([psi_ds, psi_pred]), np.max([psi_ds, psi_pred]) # min max Y
        blevels = np.linspace(bmin, bmax, 13, endpoint=True)
        ψ_mae = np.abs(psi_ds - psi_pred)
        gso_mae = np.abs(gso - gso_pred)
        # fixed error scales (so plots are comparable across samples)
        lev0 = np.linspace(0, 5.0, 13, endpoint=True)
        lev1 = np.linspace(0, 0.5, 13, endpoint=True) 
        lev2 = np.linspace(0, 0.05, 13, endpoint=True)
        lev3 = np.linspace(0, 0.005, 13, endpoint=True)
        ε = 1e-12  # keep clipped values strictly inside the level range

        im00 = axs[0,0].contourf(rr, zz, psi_ds, blevels, cmap="inferno")
        axs[0,0].set_title("Actual")
        axs[0,0].set_aspect('equal')
        axs[0,0].set_ylabel("ψ")
        fig.colorbar(im00, ax=axs[0,0]) 
        im01 = axs[0,1].contourf(rr, zz, psi_pred, blevels, cmap="inferno")
        axs[0,1].set_title("Predicted")
        fig.colorbar(im01, ax=axs[0,1])
        im02 = axs[0,2].contour(rr, zz, psi_ds, blevels, linestyles='dashed', cmap="inferno")
        axs[0,2].contour(rr, zz, psi_pred, blevels, cmap="inferno")
        axs[0,2].set_title("Contours")
        fig.colorbar(im02, ax=axs[0,2])
        im03 = axs[0,3].contourf(rr, zz, np.clip(ψ_mae, lev2[0]+ε, lev2[-1]-ε), lev2, cmap="inferno")
        axs[0,3].set_title("MAE 0.05")
        fig.colorbar(im03, ax=axs[0,3])
        im04 = axs[0,4].contourf(rr, zz, np.clip(ψ_mae, lev3[0]+ε, lev3[-1]-ε), lev3, cmap="inferno")
        axs[0,4].set_title("MAE 0.005")
        fig.colorbar(im04, ax=axs[0,4])

        im10 = axs[1,0].contourf(rr, zz, gso, gso_levels, cmap="inferno")
        axs[1,0].set_ylabel("GSO")
        fig.colorbar(im10, ax=axs[1,0])
        im11 = axs[1,1].contourf(rr, zz, gso_pred, gso_levels, cmap="inferno")
        fig.colorbar(im11, ax=axs[1,1])
        im12 = axs[1,2].contour(rr, zz, gso, gso_levels, linestyles='dashed', cmap="inferno")
        axs[1,2].contour(rr, zz, gso_pred, gso_levels, cmap="inferno")
        fig.colorbar(im12, ax=axs[1,2])
        im13 = axs[1,3].contourf(rr, zz, np.clip(gso_mae, lev0[0]+ε, lev0[-1]-ε), lev0, cmap="inferno")
        fig.colorbar(im13, ax=axs[1,3])
        im14 = axs[1,4].contourf(rr, zz, np.clip(gso_mae, lev1[0]+ε, lev1[-1]-ε), lev1, cmap="inferno")
        fig.colorbar(im14, ax=axs[1,4])

        for ax in axs.flatten(): 
            ax.grid(False), ax.set_xticks([]), ax.set_yticks([]), ax.set_aspect("equal")
            ax.plot(VESS[:,0], VESS[:,1], color="white", linewidth=2)

        plt.suptitle(f"[{JOBID}] EasyPlaNet: {titl} {i}")
        plt.tight_layout()
        if HAS_SCREEN:
            plt.show()
        else:
            plt.savefig(f"{SAVE_DIR}/imgs/planet_{titl}_{i}.png")
        plt.close()

In [None]:
# test inference speed: time single-sample forward passes on CPU and on `device`.
# "full" includes fetching and reshaping the sample; "inference only" is the forward pass.
# NOTE: reads `best_model_path` left over from the output-plotting cell.
model = EasyPlaNet()
model.load_state_dict(torch.load(f"{SAVE_DIR}/{best_model_path}", map_location=torch.device("cpu")))
model.eval()
ds = PlaNetDataset(EVAL_DS_PATH)
n_samples = 100
random_idxs = np.random.choice(len(ds), n_samples)  # n_samples random indices into ds


def _sync(dev):
    """Wait for queued work on `dev` to finish; GPU/MPS kernels run asynchronously."""
    if dev.type == "cuda":
        torch.cuda.synchronize()
    elif dev.type == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()


def _time_inference(target):
    """Move `model` to `target` and return (full, inference-only) per-sample times in ms."""
    model.to(target)
    full_times, infer_times = [], []
    with torch.no_grad():
        # one untimed warm-up pass so one-off initialisation is not measured
        x_in, _, rr, zz = ds[int(random_idxs[0])]
        model(x_in.to(target).view(1,-1), rr.to(target).view(1,1,NGZ,NGR), zz.to(target).view(1,1,NGZ,NGR))
        _sync(target)
        for i in random_idxs:
            start_t = time()
            x_in, psi_ds, rr, zz = ds[i]
            x_in, rr, zz = x_in.to(target), rr.to(target), zz.to(target)
            x_in, rr, zz = x_in.view(1,-1), rr.view(1,1,NGZ,NGR), zz.view(1,1,NGZ,NGR)
            _sync(target)
            start_t2 = time()
            model(x_in, rr, zz)
            _sync(target)  # make sure the forward pass has actually finished
            end_t = time()
            full_times.append(end_t - start_t); infer_times.append(end_t - start_t2)
    return np.array(full_times)*1000, np.array(infer_times)*1000


cpu_times1, cpu_times2 = _time_inference(torch.device("cpu"))
dev_times1, dev_times2 = _time_inference(device)  # leaves `model` on `device`
print(f"cpu: inference time: [full -> {cpu_times1.mean():.5f}ms, std: {cpu_times1.std():.5f}]")
print(f"cpu: inference time: [inference only -> {cpu_times2.mean():.5f}ms, std: {cpu_times2.std():.5f}]")
print(f"dev: inference time: [full -> {dev_times1.mean():.5f}ms, std: {dev_times1.std():.5f}]")
print(f"dev: inference time: [inference only -> {dev_times2.mean():.5f}ms, std: {dev_times2.std():.5f}]")

In [None]:
print(f"{JOBID} done", flush=True)
if not HAS_SCREEN:
    # on the cluster: wait for output files to update before the job exits
    sleep(30)

In [None]:
# Copy the job's log file (jobs/<JOBID>.txt) into the run folder without going through
# a shell. The file may not exist (e.g. for local runs); then just report it.
import shutil
try:
    shutil.copy(f"jobs/{JOBID}.txt", f"{SAVE_DIR}/log.txt")
except FileNotFoundError:
    print(f"no log file jobs/{JOBID}.txt to copy")