In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn import Module, Linear, Conv2d, MaxPool2d, BatchNorm2d, ReLU, Sequential, ConvTranspose2d
import scipy.io as sio
from time import time, sleep
import numpy as np
from numpy import pi as π
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")  # adds seaborn style to charts, eg. grid
plt.style.use("dark_background")  # inverts colors to dark theme
plt.rcParams['font.family'] = 'monospace'
np.set_printoptions(precision=3) # set precision for printing numpy arrays
import os
import warnings; warnings.filterwarnings("ignore")
# Runtime environment: on a SLURM cluster use CUDA (if present) and save figures to disk;
# otherwise assume an interactive local session and prefer the Apple-silicon GPU when available.
JOBID = os.environ.get("SLURM_JOB_ID")  # set by SLURM for cluster jobs, absent locally
if JOBID is not None:
    DEV = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") # nvidia
    HAS_SCREEN = False # no display: figures are saved instead of shown
else:
    JOBID = "local"
    # fall back to CPU when MPS is missing (e.g. local Linux/Windows) instead of failing on first use
    DEV = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    HAS_SCREEN = True
os.makedirs(f"mg_data/{JOBID}", exist_ok=True)

# DEV = torch.device("cpu") # cpu
print(f'DEV: {DEV}')

def to_tensor(x, dev=torch.device("cpu")):
    """Copy array-like `x` into a new float32 tensor living on device `dev` (CPU unless given)."""
    converted = torch.tensor(x, device=dev, dtype=torch.float32)
    return converted

In [None]:
# parameters
GRID_SIZE = 32 # size of the (square) emissivity grid
SXRV_SIZE = 21 # number of vertical sensors (match with dataset)
SXRH_SIZE = 23 # number of horizontal sensors (match with dataset)
INPUT_SIZE = SXRH_SIZE + SXRV_SIZE
DATASET_SIZE = 10_000 # number of samples in the dataset
TRAIN_DS_PATH = f"data/sxr_ds_{DATASET_SIZE}.npz"
EVAL_DS_PATH = f"data/sxr_ds_{DATASET_SIZE//10}.npz" # eval split is 1/10 of the training size
# TRAIN_DS_PATH = "data/sxr_ds_100000.npz"
# EVAL_DS_PATH = "data/sxr_ds_10000.npz"
BATCH_SIZE = 1 #128 # NOTE: batch size 1 works best
LOAD_PRETRAINED = None # path to a state_dict to resume from, or None
EPOCHS = 10
LEARNING_RATE = np.ones(EPOCHS) * 3e-4 # per-epoch learning rate schedule (constant here)

SAVE_DIR = f"mg_data/{JOBID}/lin"
os.makedirs(SAVE_DIR, exist_ok=True)
# Archive the exported training script next to the results (cluster runs). In the notebook there is
# usually no train.py, so a missing file is skipped quietly instead of spawning a shell `cp` that errors.
import shutil
try:
    shutil.copy("train.py", f"{SAVE_DIR}/train.py")
except FileNotFoundError:
    pass

In [None]:
# dataset
class SXRDataset(Dataset):
    """Soft x-ray (SXR) sensor readings paired with emissivity maps, read from an ``.npz`` archive.

    Items are ``(sxr, em)`` pairs. ``sxr`` is the horizontal (``sxrh``) and vertical (``sxrv``)
    sensor arrays concatenated along the last axis. ``em`` is the ``emiss_lr`` entry. Both are
    float32 tensors created on ``DEV`` at construction time, so ``__getitem__`` is plain indexing.
    The grid coordinates ``RR``, ``ZZ``, ``rr``, ``zz`` are kept as NumPy arrays (used for plotting).

    Args:
        ds_path: path to an ``.npz`` file with keys ``sxrh``, ``sxrv``, ``emiss_lr``,
            ``RR``, ``ZZ``, ``rr``, ``zz``.

    Raises:
        AssertionError: if the sample counts differ, the sensor count is not INPUT_SIZE,
            or ``rr`` does not have GRID_SIZE entries along its first axis.
    """
    def __init__(self, ds_path):
        # the context manager closes the npz zip handle once the arrays have been read out
        with np.load(ds_path) as ds:
            # soft x-ray horizontal and vertical sensors
            self.sxr = to_tensor(np.concatenate([ds['sxrh'], ds['sxrv']], axis=-1), DEV)
            self.em = to_tensor(ds['emiss_lr'], DEV) # emissivities (NxN per sample)
            self.RR, self.ZZ, self.rr, self.zz = ds['RR'], ds['ZZ'], ds['rr'], ds['zz'] # grid coordinates
        assert len(self.em) == len(self.sxr), f'length mismatch: {len(self.em)} vs {len(self.sxr)}'
        assert self.sxr.shape[-1] == INPUT_SIZE, f'sxr size mismatch: {self.sxr.shape[-1]}'
        # report the same array that is checked (previously the message printed RR's size)
        assert self.rr.shape[0] == GRID_SIZE, f'grid size ({self.rr.shape[0]}) is wrong, match it with dataset generator'

    def __len__(self):
        return len(self.sxr)

    def __getitem__(self, idx):
        return self.sxr[idx], self.em[idx]

# quick smoke test: the configured eval split loads and passes SXRDataset's shape assertions
# (uses EVAL_DS_PATH instead of a hard-coded file so it follows DATASET_SIZE)
ds = SXRDataset(EVAL_DS_PATH)
print(f'ds len: {len(ds)}')

In [None]:
# test dataset: plot a few random eval samples (emissivity map on top, SXR sensor signals below)
ds = SXRDataset(EVAL_DS_PATH)
print(f"Dataset length: {len(ds)}")
print(f"Input shape: {ds[0][0].shape}")
print(f"Output shape: {ds[0][1].shape}")
n_plot = 10
fig, axs = plt.subplots(2, n_plot, figsize=(3*n_plot, 5))
for i, j in enumerate(np.random.randint(0, len(ds), n_plot)):
    sxr, em = (t.cpu().numpy().squeeze() for t in ds[j])
    axs[0, i].contourf(ds.rr, ds.zz, em, 100, cmap="inferno")
    axs[0, i].contour(ds.rr, ds.zz, -em, 20, colors="black", linestyles="dotted")
    axs[0, i].axis("off")
    axs[0, i].set_aspect("equal")
    axs[1, i].plot(sxr) # sxr: concatenated horizontal + vertical sensor readings
if HAS_SCREEN:
    plt.show()
else:
    fig.savefig(f"mg_data/{JOBID}/dataset.png")
    plt.close(fig) # headless: free the figure once it is written

In [5]:
# activation functions
class Swish(Module): # custom trainable swish
    """Swish activation ``x * sigmoid(β x)`` with a learnable scalar slope β.

    Args:
        β: initial value of the trainable slope (default 1.0). Converted to float so an int
            argument still yields a parameter that can require gradients.
    """
    def __init__(self, β=1.0):
        super(Swish, self).__init__()
        self.β = torch.nn.Parameter(torch.tensor(float(β)), requires_grad=True)

    def forward(self, x):
        return x*torch.sigmoid(self.β*x)

    # No custom ``to()``: the previous override did ``self.β = self.β.to(device)``, which yields a
    # plain Tensor whenever a copy is made and is rejected by Module.__setattr__ for a registered
    # Parameter (TypeError). Module.to already moves/casts β because it is a registered Parameter,
    # and it also accepts dtype/kwargs arguments.

# class Λ(Module): # relu, uncomment for relu activation
#     def __init__(self): super(Λ, self).__init__()
#     def forward(self, x): return torch.relu(x)

class Λ(Module): # swish activation (the one SXRNet currently uses)
    """Swish ``x * sigmoid(β x)`` whose slope β is a trainable scalar parameter starting at 1."""
    def __init__(self):
        super().__init__()
        init_beta = torch.tensor(1.0)
        self.β = torch.nn.Parameter(init_beta, requires_grad=True)

    def forward(self, x):
        gate = torch.sigmoid(self.β * x)
        return x * gate

In [6]:
# network architecture: one hidden fully connected layer
WIDTH = 1024*8 # hidden-layer width
class SXRNet(Module):
    """MLP mapping INPUT_SIZE sensor readings to a batch of (1, GRID_SIZE, GRID_SIZE) emissivity maps.

    Layout: Linear(INPUT_SIZE -> WIDTH) -> Λ -> Linear(WIDTH -> GRID_SIZE**2) -> Λ, then reshape.
    """
    def __init__(self):
        super().__init__()
        n_pixels = GRID_SIZE * GRID_SIZE
        layers = [Linear(INPUT_SIZE, WIDTH), Λ(), Linear(WIDTH, n_pixels), Λ()]
        self.net = Sequential(*layers)

    def forward(self, x):
        flat = self.net(x)
        return flat.view(-1, 1, GRID_SIZE, GRID_SIZE)

In [None]:
# test model: a random input must map to one (1, GRID_SIZE, GRID_SIZE) emissivity map
model = SXRNet()
x_test = torch.randn(1, INPUT_SIZE)  # not named `input`, which would shadow the builtin for the session
with torch.no_grad():
    output = model(x_test)
print(f'Input: {x_test.shape} Output: {output.shape}')
assert output.shape == (1, 1, GRID_SIZE, GRID_SIZE)

In [None]:
def _run_epoch(model, loader, loss_fn, optimizer=None):
    """Make one pass over `loader` and return the mean per-batch loss.

    Trains (zero_grad/backward/step) when `optimizer` is given; otherwise evaluates in
    eval mode with gradient tracking disabled.

    Raises:
        ValueError: if `loader` yields no batches (the mean would be undefined).
    """
    training = optimizer is not None
    model.train(training)
    losses = []
    with torch.set_grad_enabled(training):
        for sxr, em in loader:
            loss = loss_fn(model(sxr), em) # loss between predicted and true emissivity
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            losses.append(loss.item())
    if not losses:
        raise ValueError("DataLoader yielded no batches; is the dataset empty?")
    return sum(losses) / len(losses)


def train(divergence_epoch=10, divergence_loss=9.0):
    """Train a fresh SXRNet on TRAIN_DS_PATH, validating on EVAL_DS_PATH every epoch.

    The best-so-far model (lowest eval loss) is saved to ``SAVE_DIR/mg_planet_tot.pth``; per-epoch
    losses are saved to ``SAVE_DIR/{train,eval}_tot_losses.npy`` at the end.

    Args:
        divergence_epoch: from this epoch index on, a run whose eval loss exceeds
            `divergence_loss` is aborted. NOTE: only takes effect when EPOCHS > divergence_epoch.
        divergence_loss: eval-loss threshold for the divergence check.

    Returns:
        ``(True, (train_losses, eval_losses))`` on completion, or ``(False, ())`` if the run was
        aborted as non-converging (the caller is expected to retry).
    """
    train_ds, val_ds = SXRDataset(TRAIN_DS_PATH), SXRDataset(EVAL_DS_PATH)
    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
    model = SXRNet()
    if LOAD_PRETRAINED is not None:
        model.load_state_dict(torch.load(LOAD_PRETRAINED, map_location=torch.device("cpu")))
        print(f"Pretrained model loaded: {LOAD_PRETRAINED}")
    model.to(DEV)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE[0])
    loss_fn = torch.nn.MSELoss()
    tlog_tot, elog_tot = [], [] # per-epoch mean losses
    start_time = time()
    for ep in range(EPOCHS):
        epoch_time = time()
        for pg in optimizer.param_groups:
            pg['lr'] = LEARNING_RATE[ep] # apply this epoch's learning rate
        tloss_tot = _run_epoch(model, train_dl, loss_fn, optimizer)
        eloss_tot = _run_epoch(model, val_dl, loss_fn)
        # checkpoint whenever the eval loss matches or beats every previous epoch (marked with *)
        if eloss_tot <= min(elog_tot, default=eloss_tot):
            torch.save(model.state_dict(), f"{SAVE_DIR}/mg_planet_tot.pth")
            endp = " *\n"
        else:
            endp = "\n"
        tlog_tot.append(tloss_tot)
        elog_tot.append(eloss_tot)
        # ETA covers only the epochs still to run (ep+1 of EPOCHS are done)
        eta_min = (time() - start_time) / (ep + 1) * (EPOCHS - ep - 1) / 60
        print(f"{ep+1}/{EPOCHS}: "
              f"Train: loss {tloss_tot:.4f} | "
              f"Eval: loss {eloss_tot:.4f} | "
              f"lr:{LEARNING_RATE[ep]:.1e} | "
              f"{time()-epoch_time:.0f}s, eta:{eta_min:.0f}m |", end=endp, flush=True)
        if ep >= divergence_epoch and eloss_tot > divergence_loss:
            return False, () # stop training, not converging; caller retries
    print(f"Training time: {(time()-start_time)/60:.0f}mins")
    print(f"Best losses: tot {min(elog_tot):.4f}")
    np.save(f"{SAVE_DIR}/train_tot_losses.npy", tlog_tot)
    np.save(f"{SAVE_DIR}/eval_tot_losses.npy", elog_tot)
    return True, (tlog_tot, elog_tot)

# train the model: a run flagged as non-converging returns success=False and is restarted (max 10 tries)
for attempt in range(1, 11):
    success, logs = train()
    if success:
        tlog_tot, elog_tot = logs
        break
    print(f"Convergence failed, retrying... {attempt}/10")
assert success, "Training failed"

In [None]:
# plot train/eval loss curves, once on a linear and once on a log y-scale
fig, ax = plt.subplots(1, 2, figsize=(10, 3))
ce, ct = "yellow", "red" # eval / train colors
lw = 1.0
for a, (title, yscale) in zip(ax, [("TOT Loss", "linear"), ("TOT Loss (log)", "log")]):
    a.set_title(title)
    a.plot(tlog_tot, color=ct, label="train", linewidth=lw)
    a.plot(elog_tot, color=ce, label="eval", linewidth=lw)
    a.set_yscale(yscale)
    a.legend()
    a.set_xlabel("Epoch")
    a.set_ylabel("Loss")
ax[1].grid(True, which="both", axis="y") # minor grid lines help read the log axis
plt.tight_layout()
if HAS_SCREEN:
    plt.show()
else:
    fig.savefig(f"mg_data/{JOBID}/losses.png")
    plt.close(fig) # headless: free the figure once it is written

In [None]:
# testing network output: reload the best checkpoint and compare its predictions with the dataset
for titl, best_model_path in zip(["TOT"], ["mg_planet_tot.pth"]):
    model = SXRNet()
    # the checkpoint was saved from DEV; map it to CPU so this cell also runs without that device
    model.load_state_dict(torch.load(f"{SAVE_DIR}/{best_model_path}", map_location=torch.device("cpu")))
    model.eval()
    ds = SXRDataset(EVAL_DS_PATH)
    # ds = SXRDataset(TRAIN_DS_PATH)
    rr, zz = ds.rr, ds.zz # grid coordinates
    os.makedirs(f"{SAVE_DIR}/imgs", exist_ok=True) # once, not once per figure
    N_PLOTS = 12 if HAS_SCREEN else 50
    # fixed contour levels for the two clipped squared-error panels
    mse_levels1 = np.linspace(0, 0.5, 13, endpoint=True)
    mse_levels2 = np.linspace(0, 0.05, 13, endpoint=True)
    for i in np.random.randint(0, len(ds), N_PLOTS):
        sxr, em_ds = ds[i]
        with torch.no_grad(): # inference only: no autograd graph needed
            em_pred = model(sxr.to('cpu').view(1, -1))
        em_pred = em_pred.numpy().reshape(GRID_SIZE, GRID_SIZE)
        em_ds = em_ds.to('cpu').numpy().reshape(GRID_SIZE, GRID_SIZE)
        bmin, bmax = np.min([em_ds, em_pred]), np.max([em_ds, em_pred]) # shared color range
        blevels = np.linspace(bmin, bmax, 13, endpoint=True)
        em_mse = (em_ds - em_pred)**2 # per-pixel squared error

        fig, axs = plt.subplots(1, 5, figsize=(20, 4))
        im00 = axs[0].contourf(rr, zz, em_ds, blevels, cmap="inferno")
        axs[0].set_title("Actual")
        axs[0].set_ylabel("em")
        fig.colorbar(im00, ax=axs[0])
        im01 = axs[1].contourf(rr, zz, em_pred, blevels, cmap="inferno")
        axs[1].set_title("Predicted")
        fig.colorbar(im01, ax=axs[1])
        im02 = axs[2].contour(rr, zz, em_ds, blevels, linestyles='dashed', cmap="inferno") # dashed: actual
        axs[2].contour(rr, zz, em_pred, blevels, cmap="inferno") # solid: predicted
        axs[2].set_title("Contours")
        fig.colorbar(im02, ax=axs[2])
        im03 = axs[3].contourf(rr, zz, np.clip(em_mse, 0, 0.5), mse_levels1, cmap="inferno")
        axs[3].set_title("MSE 0.5")
        fig.colorbar(im03, ax=axs[3])
        im04 = axs[4].contourf(rr, zz, np.clip(em_mse, 0.00001, 0.04999), mse_levels2, cmap="inferno")
        axs[4].set_title("MSE 0.05")
        fig.colorbar(im04, ax=axs[4])
        for ax in axs.flatten():
            ax.grid(False)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_aspect("equal")

        fig.suptitle(f"SXRNet: {titl} {i}")
        plt.tight_layout()
        if HAS_SCREEN:
            plt.show()
        else:
            fig.savefig(f"{SAVE_DIR}/imgs/planet_{titl}_{i}.png")
        plt.close(fig)

In [None]:
# test inference speed: single-sample forward-pass latency on CPU vs DEV
from time import perf_counter

def _sync(dev):
    """Wait for queued work on `dev`, so wall-clock timings include asynchronous GPU compute."""
    if dev.type == "cuda":
        torch.cuda.synchronize()
    elif dev.type == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()

best_model_path = "mg_planet_tot.pth" # explicit, instead of relying on the previous cell's loop variable
model = SXRNet()
model.load_state_dict(torch.load(f"{SAVE_DIR}/{best_model_path}", map_location=torch.device("cpu")))
model.eval()
ds = SXRDataset(EVAL_DS_PATH)
n_samples = 100
# n_samples indices drawn from the whole dataset (the arguments were previously swapped)
random_idxs = np.random.choice(len(ds), n_samples)
N_WARMUP = 10 # untimed calls so one-off initialization does not skew the first measurements

with torch.no_grad(): # pure inference: exclude autograd bookkeeping from the timings
    # cpu (data fetch/transfer is excluded from the timed region, as on DEV)
    cpu_times = []
    warm = ds[random_idxs[0]][0].to('cpu').view(1, -1)
    for _ in range(N_WARMUP): model(warm)
    for i in random_idxs:
        sxr = ds[i][0].to('cpu').view(1, -1)
        start_t = perf_counter()
        em_pred = model(sxr)
        cpu_times.append(perf_counter() - start_t)
    # DEV
    model.to(DEV)
    dev_times = []
    warm = ds[random_idxs[0]][0].view(1, -1) # dataset tensors already live on DEV
    for _ in range(N_WARMUP): model(warm)
    for i in random_idxs:
        sxr = ds[i][0].view(1, -1)
        _sync(DEV)
        start_t = perf_counter()
        em_pred = model(sxr)
        _sync(DEV) # GPU kernels run asynchronously; wait for them before stopping the clock
        dev_times.append(perf_counter() - start_t)
cpu_times, dev_times = np.array(cpu_times), np.array(dev_times)
print(f"cpu: inference time: {cpu_times.mean():.5f}s, std: {cpu_times.std():.5f}")
print(f"dev: inference time: {dev_times.mean():.5f}s, std: {dev_times.std():.5f}")