# Train

### Ideas
- Use a dedicated network to handle missing values (e.g. removed sensors)
- Predict the clean (noise-free) input
- Train jointly with the clean input
- Train jointly with … (unfinished note — TODO: complete)

In [None]:
from utils import *
from common import GRID_SIZE, INPUT_SIZE, VDI_INTERVAL, VDC_INTERVAL, VDE_INTERVAL, HOR1_INTERVAL

In [2]:
# parameters
# ---- dataset ---------------------------------------------------------------
N_DS = 50_000  # number of samples in the (training) dataset
TRAIN_DS_PATH = f"data/sxr_ds_{N_DS}.npz"
EVAL_DS_PATH = f"data/sxr_ds_{N_DS//10}.npz"  # evaluation set is 10x smaller

# ---- optimisation ----------------------------------------------------------
BATCH_SIZE = 1  # NOTE: batch size 1 works best (128 was also tried)
EPOCHS = 6  # previously 10
LEARNING_RATE = np.full(EPOCHS, 3e-4)  # one learning rate per epoch (constant schedule)

REAL_DATASET = True  # use real dataset (set to False to disable)

# ---- architecture ----------------------------------------------------------
# alternatives: SXRNetU32, SXRNetU32Big, SXRNetU64, SXRNetLinear2
ARCHITECTURE = SXRNetLinear1

# ---- data corruption during training ---------------------------------------
NOISE_LEVEL = 0.0  # noise level [fraction on the mean]; e.g. 0.05
RANDOM_REMOVE = 0  # number of random sensors to remove each time; e.g. 3

# ---- output / checkpoints --------------------------------------------------
N_PLOTS = 50 if not HAS_SCREEN else 6  # fewer plots when running interactively
LOAD_PRETRAINED = None  # pretrained model path, e.g. SAVE_DIR + "/mg_tomo_best.pth"
SAVE_PATH = f"{SAVE_DIR}/mg_tomo.pth"  # where the best model is saved

In [None]:
# test dataset: print basic info and plot a few random samples
ds = SXRDataset(N_DS//10, REAL_DATASET, NOISE_LEVEL, RANDOM_REMOVE)
sxr0, em0 = ds[0]  # fetch the first sample once rather than indexing ds[0] twice
print(f"Dataset length: {len(ds)}")
print(f"Input shape: {sxr0.shape}")
print(f"Output shape: {em0.shape}")
n_plot = 10
fig, axs = plt.subplots(2, n_plot, figsize=(3*n_plot, 5))
for i, j in enumerate(np.random.randint(0, len(ds), n_plot)):
    # Index the dataset once so input and target come from the same __getitem__
    # call (this matters if the dataset randomises samples on each access —
    # not visible from here, so be safe).
    sxr_t, em_t = ds[j]
    sxr, em = sxr_t.cpu().numpy().squeeze(), em_t.cpu().numpy()
    # top row: target map on the (RRL, ZZL) grid
    axs[0, i].contourf(ds.RRL, ds.ZZL, em, 100, cmap="inferno")
    axs[0, i].axis("off")
    axs[0, i].set_aspect("equal")
    # bottom row: sensor readings
    axs[1, i].plot(sxr, 'rs')
if HAS_SCREEN:
    plt.show()
else:
    plt.savefig(f"mg_data/{JOBID}/dataset.png")
    plt.close(fig)  # free the figure when running headless

In [None]:
# test model: push one random input through a freshly built model to check shapes
model = ARCHITECTURE(INPUT_SIZE)
x = torch.randn(1, INPUT_SIZE)  # batch of one; named `x` to avoid shadowing builtin input()
with torch.no_grad():  # shape check only, no gradients needed
    output = model(x)
print(f'Input: {x.shape} Output: {output.shape}')

In [None]:
# training
def train():
    """Train a fresh ARCHITECTURE model and checkpoint the best one.

    Builds training/validation ``SXRDataset``s from the module-level
    hyper-parameters, optionally initialises the weights from
    ``LOAD_PRETRAINED``, and runs ``EPOCHS`` epochs of Adam on an MSE loss
    between ``model(sxr)`` and the target ``em``. The weights are written to
    ``SAVE_PATH`` whenever the mean validation loss matches or improves on the
    best seen so far; per-epoch mean losses are saved under ``SAVE_DIR``.

    Returns:
        ``(success, logs)``. ``success`` is False when training was aborted
        because the validation loss was not converging (then ``logs == ()``);
        otherwise True with ``logs == (train_losses, eval_losses)``, the lists
        of per-epoch mean losses.
    """
    train_ds = SXRDataset(N_DS, REAL_DATASET, NOISE_LEVEL, RANDOM_REMOVE)
    val_ds = SXRDataset(N_DS//10, REAL_DATASET, NOISE_LEVEL, RANDOM_REMOVE)
    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

    model = ARCHITECTURE(INPUT_SIZE)
    if LOAD_PRETRAINED is not None:
        model.load_state_dict(torch.load(LOAD_PRETRAINED, map_location=torch.device("cpu")))
        print(f"Pretrained model loaded: {LOAD_PRETRAINED}")
        torch.save(model.state_dict(), SAVE_PATH)  # start from a valid checkpoint at SAVE_PATH
    model.to(DEV)

    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE[0])
    loss_fn = torch.nn.MSELoss()
    tlog_tot, elog_tot = [], []  # per-epoch mean train / eval losses
    start_time = time()
    for ep in range(EPOCHS):
        epoch_time = time()
        for pg in optimizer.param_groups:
            pg['lr'] = LEARNING_RATE[ep]  # per-epoch learning-rate schedule

        # --- training pass ---
        model.train()
        trainloss = torch.zeros(len(train_dl))
        batches = tqdm(train_dl, desc=f"Epoch {ep+1}/{EPOCHS}", leave=False) if HAS_SCREEN else train_dl
        for bi, (sxr, em) in enumerate(batches):
            # keep inputs on the same device as the model (no-op if already on DEV)
            sxr, em = sxr.to(DEV), em.to(DEV)
            optimizer.zero_grad()
            loss = loss_fn(model(sxr), em)
            loss.backward()
            optimizer.step()
            trainloss[bi] = loss.item()

        # --- validation pass ---
        model.eval()
        evalloss = torch.zeros(len(val_dl))
        with torch.no_grad():
            for bi, (sxr, em) in enumerate(val_dl):
                sxr, em = sxr.to(DEV), em.to(DEV)
                evalloss[bi] = loss_fn(model(sxr), em).item()

        tloss_tot = trainloss.mean().item()
        eloss_tot = evalloss.mean().item()
        # save model if the validation loss is at least as good as the best so far
        if eloss_tot <= min(elog_tot, default=eloss_tot):
            torch.save(model.state_dict(), SAVE_PATH)
            endp = " *\n"
        else:
            endp = "\n"
        tlog_tot.append(tloss_tot)
        elog_tot.append(eloss_tot)

        # ETA: average epoch duration so far times the epochs still to run
        eta_min = (time() - start_time) / (ep + 1) * (EPOCHS - ep - 1) / 60
        print(f"{ep+1}/{EPOCHS}: "
              f"Train: loss {tloss_tot:.6f} | "
              f"Eval: loss {eloss_tot:.6f} | "
              f"lr:{LEARNING_RATE[ep]:.1e} | "
              f"{time()-epoch_time:.0f}s, eta:{eta_min:.0f}m |", end=endp, flush=True)
        if ep >= 10 and eloss_tot > 9.0:
            return False, ()  # not converging: let the caller retry from scratch
    print(f"Training time: {(time()-start_time)/60:.0f}mins")
    print(f"Best losses: tot {min(elog_tot):.4f}")
    np.save(f"{SAVE_DIR}/train_tot_losses.npy", tlog_tot)
    np.save(f"{SAVE_DIR}/eval_tot_losses.npy", elog_tot)
    return True, (tlog_tot, elog_tot)

# train the model (multiple attempts): train() reports success=False when the
# validation loss does not converge, in which case training restarts from scratch
MAX_TRAIN_ATTEMPTS = 10
for i in range(MAX_TRAIN_ATTEMPTS):
    success, logs = train()
    if success:
        tlog_tot, elog_tot = logs
        break
    print(f"Convergence failed, retrying... {i+1}/{MAX_TRAIN_ATTEMPTS}")
else:
    # explicit raise instead of assert: asserts are stripped under `python -O`
    raise RuntimeError("Training failed")

In [None]:
# plot losses: the same train/eval curves on a linear (left) and log (right) y-axis
fig, ax = plt.subplots(1, 2, figsize=(10, 3))
train_color, eval_color = "red", "yellow"
line_width = 1.0
for axis, title in zip(ax, ("TOT Loss", "TOT Loss (log)")):
    axis.set_title(title)
    axis.plot(tlog_tot, color=train_color, label="train", linewidth=line_width)
    axis.plot(elog_tot, color=eval_color, label="eval", linewidth=line_width)

log_axis = ax[1]
log_axis.set_yscale("log")
log_axis.grid(True, which="both", axis="y")

for axis in ax.flatten():
    axis.legend()
    axis.set_xlabel("Epoch")
    axis.set_ylabel("Loss")
plt.tight_layout()
if HAS_SCREEN:
    plt.show()
else:
    plt.savefig(f"mg_data/{JOBID}/losses.png")

In [None]:
# testing clean data network output: plot the prediction of the saved model
# on N_PLOTS random samples, on the CPU
model = ARCHITECTURE(INPUT_SIZE)
model.load_state_dict(torch.load(SAVE_PATH, map_location=torch.device("cpu")))
model.eval()
ds = SXRDataset(N_DS//10, REAL_DATASET)
rr, zz = ds.RRL, ds.ZZL # grid coordinates
# sensor index ranges of the four cameras inside the flat input vector
vdi0, vdi1 = VDI_INTERVAL[0], VDI_INTERVAL[1]
vdc0, vdc1 = VDC_INTERVAL[0], VDC_INTERVAL[1]
vde0, vde1 = VDE_INTERVAL[0], VDE_INTERVAL[1]
hor0, hor1 = HOR1_INTERVAL[0], HOR1_INTERVAL[1]
with torch.no_grad():  # inference only: skip building the autograd graph
    for i in np.random.randint(0, len(ds), N_PLOTS):
        sxr, em_ds = ds[i]
        em = em_ds.detach().cpu().numpy().reshape(GRID_SIZE, GRID_SIZE)
        sxr = sxr.to('cpu').view(1, -1)  # model is on the CPU; add batch dimension
        # clean data
        em_pred = model(sxr).detach().numpy().reshape(GRID_SIZE, GRID_SIZE)
        vdi, vdc, vde, hor = sxr[0, vdi0:vdi1], sxr[0, vdc0:vdc1], sxr[0, vde0:vde1], sxr[0, hor0:hor1]
        plot_net_example(em, em_pred, [vdi, vdc, vde, hor], rr, zz, f"CLEAN {i}")

In [None]:
# testing noisy network output: for N_PLOTS random samples plot the prediction on
# (1) the clean input, (2) the input with added Gaussian noise, and
# (3) the noisy input with RANDOM_REMOVE randomly chosen sensors set to zero
model = ARCHITECTURE(INPUT_SIZE)
model.load_state_dict(torch.load(SAVE_PATH, map_location=torch.device("cpu")))
model.eval()
ds = SXRDataset(N_DS//10, REAL_DATASET)
rr, zz = ds.RRL, ds.ZZL # grid coordinates
vdi0, vdi1 = VDI_INTERVAL[0], VDI_INTERVAL[1]
vdc0, vdc1 = VDC_INTERVAL[0], VDC_INTERVAL[1]
vde0, vde1 = VDE_INTERVAL[0], VDE_INTERVAL[1]
hor0, hor1 = HOR1_INTERVAL[0], HOR1_INTERVAL[1]
noise_frac = 0.05  # noise std as a fraction of the sample's max sensor value


def _split_sensors(x):
    """Return the VDI, VDC, VDE and HOR camera slices of a (1, INPUT_SIZE) input."""
    return [x[0, vdi0:vdi1], x[0, vdc0:vdc1], x[0, vde0:vde1], x[0, hor0:hor1]]


with torch.no_grad():  # inference only: skip building the autograd graph
    for i in np.random.randint(0, len(ds), N_PLOTS):
        sxr, em_ds = ds[i]
        # .cpu() before .numpy(): dataset tensors may live on another device
        em = em_ds.detach().cpu().numpy().reshape(GRID_SIZE, GRID_SIZE)
        sxr = sxr.to('cpu').view(1, -1)  # model is on the CPU; add batch dimension
        # clean data
        em_pred = model(sxr).detach().numpy().reshape(GRID_SIZE, GRID_SIZE)
        plot_net_example(em, em_pred, _split_sensors(sxr), rr, zz, f"CLEAN {i}")
        # noisy data
        noisy_sxr = sxr + torch.randn_like(sxr) * noise_frac * torch.max(sxr)
        em_pred = model(noisy_sxr).detach().numpy().reshape(GRID_SIZE, GRID_SIZE)
        plot_net_example(em, em_pred, _split_sensors(noisy_sxr), rr, zz, f"NOISY {i}")
        # noisy data with random sensors removed: work on a copy so the slices
        # already handed to the NOISY plot are not mutated, and draw distinct
        # sensor indices so exactly RANDOM_REMOVE sensors are zeroed
        removed_sxr = noisy_sxr.clone()
        removed_sxr[0, np.random.choice(INPUT_SIZE, RANDOM_REMOVE, replace=False)] = 0
        em_pred = model(removed_sxr).detach().numpy().reshape(GRID_SIZE, GRID_SIZE)
        plot_net_example(em, em_pred, _split_sensors(removed_sxr), rr, zz, f"NOISY-REMOVED {i}")
        print(f"------------------------------------------------------------------------------------------")

In [None]:
# test inference speed: time single-sample forward passes on the CPU and on DEV
from time import perf_counter

model = ARCHITECTURE(INPUT_SIZE)
model.load_state_dict(torch.load(SAVE_PATH, map_location=torch.device("cpu")))
model.eval()
ds = SXRDataset(N_DS//10, REAL_DATASET, 0.0)
n_samples = 100
n_warmup = 10  # untimed passes to absorb one-off costs (lazy init, kernel setup)
random_idxs = np.random.choice(len(ds), n_samples)  # n_samples indices drawn from the whole dataset


def _time_inference(device):
    """Return an array of per-sample forward-pass times [s] of `model` on `device`.

    Inputs are fetched and moved to `device` before timing, so only the forward
    pass is measured on both devices. CUDA kernels run asynchronously, so the
    GPU is synchronised around each call; other asynchronous backends are not
    synchronised here.
    """
    dev = torch.device(device)
    is_cuda = dev.type == "cuda"
    model.to(dev)
    inputs = [ds[i][0].to(dev).view(1, -1) for i in random_idxs]
    times = []
    with torch.no_grad():  # measure inference, not autograd bookkeeping
        for sxr in inputs[:n_warmup]:
            model(sxr)
        for sxr in inputs:
            if is_cuda:
                torch.cuda.synchronize(dev)
            start_t = perf_counter()
            model(sxr)
            if is_cuda:
                torch.cuda.synchronize(dev)
            times.append(perf_counter() - start_t)
    return np.array(times)


cpu_times = _time_inference("cpu")
dev_times = _time_inference(DEV)  # leaves the model on DEV
print(f"cpu: inference time: {1000*cpu_times.mean():.3f}[ms], std: {1000*cpu_times.std():.3f}[ms]")
print(f"dev: inference time: {1000*dev_times.mean():.3f}[ms], std: {1000*dev_times.std():.3f}[ms]")