In [None]:
# --- imports and global notebook configuration (plot style, numpy printing, warnings) ---
# NOTE(review): ReLU, transforms, torchvision, sio, sleep and π are not used in the cells shown here.
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn import Module, Linear, Conv2d, MaxPool2d, BatchNorm2d, ReLU, Sequential, ConvTranspose2d
from torchvision import transforms
import torchvision
import scipy.io as sio
from time import time, sleep
import numpy as np
from numpy import pi as π
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")  # adds seaborn style to charts, eg. grid
plt.style.use("dark_background")  # inverts colors to dark theme (applied after seaborn, so it takes precedence where they overlap)
plt.rcParams['font.family'] = 'monospace'  # default font family for matplotlib text
np.set_printoptions(precision=3) # set precision for printing numpy arrays
import os
# NOTE(review): this suppresses *every* warning, including PyTorch's MSELoss warning about
# input/target size mismatches (which then broadcast silently); consider a narrower filter.
import warnings; warnings.filterwarnings("ignore")
# Choose compute device, job id and plotting mode depending on where the notebook runs:
# inside a SLURM job (cluster, headless) or in an interactive local session.
try:
    JOBID = os.environ["SLURM_JOB_ID"]  # set by SLURM when running as a cluster job
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # nvidia
    HAS_SCREEN = False  # headless: figures are saved to disk instead of shown
except KeyError:  # SLURM_JOB_ID not set -> local run
    # prefer the Apple-silicon GPU, but fall back to CPU when MPS is not available
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = torch.device("mps")  # apple silicon
    else:
        device = torch.device("cpu")
    JOBID = "local"
    HAS_SCREEN = True
os.makedirs(f"mg_data/{JOBID}", exist_ok=True)
print(f'device: {device}')

def to_tensor(x, device=torch.device("cpu")):
    """Copy array-like/scalar ``x`` into a new float32 tensor placed on ``device``."""
    tensor = torch.tensor(x, device=device, dtype=torch.float32)
    return tensor

In [None]:
# parameters
SXRV_SIZE = 21  # number of vertical sensors (must match the dataset)
SXRH_SIZE = 23  # number of horizontal sensors (must match the dataset)
INPUT_SIZE = SXRH_SIZE + SXRV_SIZE  # model input: horizontal readings followed by vertical ones
# KHR = 8 # multiplier for high resolution
# SXR_HR_SIZE = INPUT_SIZE*KHR # high resolution size
TRAIN_DS_PATH = "data/sxr_sg_ds_10000.npz"
EVAL_DS_PATH = "data/sxr_sg_ds_1000.npz"
# TRAIN_DS_PATH = "data/sxr_sg_ds_100000.npz"
# EVAL_DS_PATH = "data/sxr_sg_ds_10000.npz"
BATCH_SIZE = 1
LOAD_PRETRAINED = None  # path to a saved state_dict to start from, or None to train from scratch
EPOCHS = 10
LEARNING_RATE = np.ones(EPOCHS) * 3e-4  # per-epoch learning-rate schedule (constant here)

SAVE_DIR = f"mg_data/{JOBID}/sg"
os.makedirs(SAVE_DIR, exist_ok=True)
# keep a copy of the training script next to the outputs (cluster runs);
# when train.py is not present (e.g. running this notebook locally) skip silently
import shutil
try:
    shutil.copy("train.py", f"{SAVE_DIR}/train.py")
except FileNotFoundError:
    pass

In [None]:
# dataset
class SXRDataset(Dataset):
    """In-memory soft-X-ray dataset loaded from an ``.npz`` file.

    Reads the arrays ``sxrh`` / ``sxrv`` (horizontal / vertical sensor readings),
    ``rr`` / ``zz`` (grid coordinates) and ``emiss_sg`` (target emissivity),
    converts them to float32 and stores them on the global ``device``.

    ``__getitem__`` returns ``(sxr, em, rr, zz)``: ``sxr`` has length INPUT_SIZE
    (horizontal readings first, then vertical); ``em``, ``rr`` and ``zz`` each
    have shape (1, 64, 64), i.e. the same layout as one model output.
    """
    def __init__(self, ds_path):
        with np.load(ds_path) as ds:  # closes the underlying .npz file handle when done
            sxrh = ds['sxrh']  # soft x ray horizontal
            sxrv = ds['sxrv']  # soft x ray vertical
            self.sxr = to_tensor(np.concatenate([sxrh, sxrv], axis=-1), device)
            self.rr = to_tensor(ds['rr'], device).view(-1, 1, 64, 64)
            self.zz = to_tensor(ds['zz'], device).view(-1, 1, 64, 64)
            # same (N, 1, 64, 64) layout as rr/zz and as the model output, so MSELoss
            # compares element-wise instead of broadcasting (B,1,64,64) vs (B,64,64)
            self.em = to_tensor(ds['emiss_sg'], device).view(-1, 1, 64, 64)
        assert len(self.em) == len(self.sxr), f'length mismatch: {len(self.em)} vs {len(self.sxr)}'
        assert self.sxr.shape[-1] == INPUT_SIZE, f'sxr size mismatch: {self.sxr.shape[-1]}'

    def __len__(self):
        return len(self.sxr)

    def __getitem__(self, idx):
        return self.sxr[idx], self.em[idx], self.rr[idx], self.zz[idx]

# smoke-test dataset loading on the configured evaluation file
ds = SXRDataset(EVAL_DS_PATH)
print(f'ds len: {len(ds)}')

In [None]:
# visual sanity check: random samples of the target emissivity (top row)
# and the corresponding concatenated SXR sensor signals (bottom row)
ds = SXRDataset(EVAL_DS_PATH)
print(f"Dataset length: {len(ds)}")
print(f"Input shape: {ds[0][0].shape}")
print(f"Output shape: {ds[0][1].shape}")
n_plot = 10
fig, axs = plt.subplots(2, n_plot, figsize=(3*n_plot, 5))
for i, j in enumerate(np.random.randint(0, len(ds), n_plot)):
    sxr, em, rr, zz = (t.cpu().numpy().squeeze() for t in ds[j])  # fetch sample j once
    axs[0,i].contourf(rr, zz, em, 100, cmap="inferno")
    axs[0,i].contour(rr, zz, -em, 20, colors="black", linestyles="dotted")
    axs[0,i].axis("off")
    axs[0,i].set_aspect("equal")
    axs[1,i].plot(sxr)  # horizontal readings followed by vertical readings
if HAS_SCREEN:
    plt.show()
else:
    plt.savefig(f"mg_data/{JOBID}/dataset.png")
plt.close(fig)

In [5]:
# activation functions
class Swish(Module): # custom trainable swish
    """Swish activation ``x * sigmoid(β * x)`` with a learnable scalar ``β``.

    Args:
        β: initial value of the slope parameter (int or float), stored as a
           float32 ``nn.Parameter``.

    ``β`` is a registered parameter, so the inherited ``Module.to`` / ``.cuda()`` /
    ``.double()`` move and cast it together with all other parameters.
    """
    def __init__(self, β=1.0):
        super().__init__()
        # float(): an int β would create an integer tensor, which cannot require grad
        self.β = torch.nn.Parameter(torch.tensor(float(β)), requires_grad=True)

    def forward(self, x):
        return x * torch.sigmoid(self.β * x)

# class Λ(Module): # relu, uncomment for relu activation
#     def __init__(self): super(Λ, self).__init__()
#     def forward(self, x): return torch.relu(x)

class Λ(Module):  # activation used throughout SXRNet: trainable swish
    """Trainable Swish ``x * sigmoid(β * x)`` with ``β`` initialised to 1."""

    def __init__(self):
        super().__init__()
        # attribute must stay named ``β``: it is the state_dict key in saved checkpoints
        self.β = torch.nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        gate = torch.sigmoid(self.β * x)
        return x * gate

In [6]:
# MODEL: VERY BIG SXRNet: same architecture as the base net, more neurons per layer
class SXRNet(Module): # Paper net: branch + trunk connection, then a decoder head
    """Predict a (B, 1, 64, 64) emissivity map from SXR readings and the (R, Z) grid.

    ``forward`` takes a tuple ``(sxr, rr, zz)``:
      - ``sxr``: (B, INPUT_SIZE) sensor readings -> branch MLP -> (B, 64)
      - ``rr``, ``zz``: (B, 1, 64, 64) coordinate maps -> one conv trunk each
        (64x64 -> 8x8, 32 channels), flattened, concatenated -> trunk MLP -> (B, 64)
    Branch and trunk features are multiplied element-wise, projected to 4096 =
    64 * 8 * 8 values, reshaped to (B, 64, 8, 8) and decoded to (B, 1, 64, 64).
    Every ``Λ()`` is a separate trainable-swish module with its own ``β``.
    """
    def __init__(self):
        super(SXRNet, self).__init__()
        # branch: MLP on the sensor vector, INPUT_SIZE -> 256 -> 128 -> 64
        self.branch = Sequential(
            Linear(INPUT_SIZE, 256), Λ(),
            Linear(256, 128), Λ(),
            Linear(128, 64), Λ(),
        )
        # trunk: identical (but independently parameterised) conv encoders for rr and zz
        def trunk_block(): 
            """Conv encoder: (B, 1, 64, 64) -> (B, 32, 8, 8) via three conv/BN/Λ/max-pool stages."""
            return  Sequential(
                Conv2d(1, 8, kernel_size=3, stride=1, padding=1), BatchNorm2d(8), Λ(), MaxPool2d(2),
                Conv2d(8, 16, kernel_size=3, stride=1, padding=1), BatchNorm2d(16), Λ(), MaxPool2d(2),
                Conv2d(16, 32, kernel_size=3, stride=1, padding=1), BatchNorm2d(32), Λ(), MaxPool2d(2),
            )
        self.trunk_r, self.trunk_z = trunk_block(), trunk_block()
        # trunk MLP on the concatenated flattened r/z features: 2*32*8*8 -> 64
        self.trunk_fc = Sequential(
            Linear(2*32*8*8, 128), Λ(),
            Linear(128, 64), Λ(),
            Linear(64, 64), Λ(), 
        )
        # head: 64 -> 4096 values, reshaped to (64, 8, 8) in forward
        self.fc = Sequential(Linear(64, 4096), Λ())
        self.anti_conv = Sequential( # decoder: U-Net-like upsampling stages, but without skip connections
            ConvTranspose2d(64, 64, kernel_size=2, stride=2), # 8 -> 16
            Conv2d(64, 64, kernel_size=3, padding=0), Λ(), # 16 -> 14 (unpadded convs shrink by 2)
            Conv2d(64, 64, kernel_size=3, padding=0), Λ(), # 14 -> 12
            ConvTranspose2d(64, 32, kernel_size=2, stride=2), # 12 -> 24
            Conv2d(32, 32, kernel_size=3, padding=0), Λ(), # 24 -> 22
            Conv2d(32, 32, kernel_size=3, padding=0), Λ(), # 22 -> 20
            ConvTranspose2d(32, 16, kernel_size=2, stride=2), # 20 -> 40
            Conv2d(16, 16, kernel_size=3, padding=0), Λ(), # 40 -> 38
            Conv2d(16, 16, kernel_size=3, padding=0), Λ(), # 38 -> 36
            ConvTranspose2d(16, 8, kernel_size=2, stride=2), # 36 -> 72
            Conv2d(8, 4, kernel_size=3, padding=0), Λ(), # 72 -> 70
            Conv2d(4, 2, kernel_size=3, padding=0), Λ(), # 70 -> 68
            Conv2d(2, 1, kernel_size=5, padding=0), # 68 -> 64, linear output (no activation)
        )
    def forward(self, x):
        """``x = (sxr, rr, zz)``; returns the predicted map of shape (B, 1, 64, 64)."""
        xb, r, z = x
        # branch net: (B, INPUT_SIZE) -> (B, 64)
        xb = self.branch(xb)
        # trunk net: two (B, 1, 64, 64) maps -> (B, 64)
        r, z = self.trunk_r(r), self.trunk_z(z) # convolutions
        r, z = r.view(-1, 32*8*8), z.view(-1, 32*8*8) # flatten
        xt = torch.cat((r, z), 1) # concatenate
        xt = self.trunk_fc(xt) # fully connected
        # multiply trunk and branch (element-wise, both (B, 64))
        x = xt * xb
        # head net: (B, 64) -> (B, 4096) -> (B, 64, 8, 8) -> (B, 1, 64, 64)
        x = self.fc(x)
        x = x.view(-1, 64, 8, 8)
        x = self.anti_conv(x)
        return x

In [None]:
# smoke-test the model: push one random sample through the network and check the output shape
model = SXRNet()
sample_input = (torch.randn(1, INPUT_SIZE), torch.randn(1, 1, 64, 64), torch.randn(1, 1, 64, 64))
with torch.no_grad():  # no autograd graph needed for a shape check
    output = model(sample_input)
print(f'Input: {sample_input[0].shape}, {sample_input[1].shape}, {sample_input[2].shape}, \nOutput: {output.shape}')
assert output.shape == (1, 1, 64, 64), f"unexpected output shape {tuple(output.shape)}"

In [None]:
def train(patience_epochs=10, max_eval_loss=9.0):
    """Train a fresh SXRNet once and keep the best checkpoint.

    Loads the train/eval datasets, optionally initialises the weights from
    LOAD_PRETRAINED, and runs EPOCHS epochs of Adam on an MSE loss. The learning
    rate for each epoch comes from LEARNING_RATE. Whenever the mean eval loss
    is <= the best so far, the state_dict is saved to ``{SAVE_DIR}/mg_planet_tot.pth``.
    On completion, per-epoch losses are saved as .npy files in SAVE_DIR.

    Args:
        patience_epochs: 0-based epoch index from which the divergence check is
            active. With the default and EPOCHS <= 10 the check never fires.
        max_eval_loss: mean eval loss above which a run past ``patience_epochs``
            is treated as non-converging and aborted.

    Returns:
        ``(True, (train_losses, eval_losses))`` when all epochs complete, or
        ``(False, ())`` when the run is aborted (the caller may retry).
    """
    train_ds, val_ds = SXRDataset(TRAIN_DS_PATH), SXRDataset(EVAL_DS_PATH)
    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
    model = SXRNet()
    if LOAD_PRETRAINED is not None:
        model.load_state_dict(torch.load(LOAD_PRETRAINED, map_location=torch.device("cpu")))
        print(f"Pretrained model loaded: {LOAD_PRETRAINED}")
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE[0])
    loss_fn = torch.nn.MSELoss()
    tlog_tot, elog_tot = [], []  # per-epoch mean train / eval losses
    start_time = time()
    for ep in range(EPOCHS):
        epoch_time = time()
        for pg in optimizer.param_groups:
            pg['lr'] = LEARNING_RATE[ep]  # per-epoch learning-rate schedule
        model.train()
        trainloss, evalloss = [], []
        for sxr, em, rr, zz in train_dl:
            optimizer.zero_grad()
            em_pred = model((sxr, rr, zz))
            # reshape_as: compare element-wise even if targets lack the channel dim;
            # otherwise MSELoss broadcasts (B,1,64,64) vs (B,64,64) to (B,B,64,64)
            loss = loss_fn(em_pred, em.reshape_as(em_pred))
            loss.backward()
            optimizer.step()
            trainloss.append(loss.item())
        model.eval()
        with torch.no_grad():
            for sxr, em, rr, zz in val_dl:
                em_pred = model((sxr, rr, zz))
                evalloss.append(loss_fn(em_pred, em.reshape_as(em_pred)).item())
        tloss_tot = sum(trainloss) / len(trainloss)  # mean of batch losses
        eloss_tot = sum(evalloss) / len(evalloss)
        # save the model if eval loss did not get worse (always true on the first epoch)
        improved = eloss_tot <= min(elog_tot, default=eloss_tot)
        if improved:
            torch.save(model.state_dict(), f"{SAVE_DIR}/mg_planet_tot.pth")
        tlog_tot.append(tloss_tot)
        elog_tot.append(eloss_tot)
        # ETA: mean epoch time so far times the epochs still to run after this one
        eta_min = (time() - start_time) / (ep + 1) * (EPOCHS - ep - 1) / 60
        print(f"{ep+1}/{EPOCHS}: "
              f"Eval: loss {eloss_tot:.4f} | "
              f"lr:{LEARNING_RATE[ep]:.1e} | "
              f"{time()-epoch_time:.0f}s, eta:{eta_min:.0f}m |", end=" *\n" if improved else "\n", flush=True)
        if ep >= patience_epochs and eloss_tot > max_eval_loss:
            return False, ()  # not converging: stop so the caller can retry from scratch
    print(f"Training time: {(time()-start_time)/60:.0f}mins")
    print(f"Best losses: tot {min(elog_tot):.4f}")
    np.save(f"{SAVE_DIR}/train_tot_losses.npy", tlog_tot)
    np.save(f"{SAVE_DIR}/eval_tot_losses.npy", elog_tot)
    return True, (tlog_tot, elog_tot)

# train the model, restarting from scratch when a run does not converge (at most 10 attempts)
for attempt in range(1, 11):
    success, logs = train()
    if success:
        tlog_tot, elog_tot = logs
        break
    print(f"Convergence failed, retrying... {attempt}/10")
assert success, "Training failed"

In [None]:
# plot per-epoch mean losses; train() only logs the total MSE loss, shown on a linear
# (top) and a logarithmic (bottom) y-axis
ce, ct = "yellow", "red"  # eval / train line colors
lw = 1.0
fig, loss_axes = plt.subplots(2, 1, figsize=(8, 6))
for a, yscale in zip(loss_axes, ("linear", "log")):
    a.plot(tlog_tot, color=ct, label="train", linewidth=lw)
    a.plot(elog_tot, color=ce, label="eval", linewidth=lw)
    a.set_yscale(yscale)
    a.set_title("TOT Loss" if yscale == "linear" else "TOT Loss (log)")
    a.set_xlabel("Epoch")
    a.set_ylabel("Loss")
    a.legend()
loss_axes[1].grid(True, which="both", axis="y")  # minor grid lines help read the log scale
plt.tight_layout()
if HAS_SCREEN:
    plt.show()
else:
    plt.savefig(f"mg_data/{JOBID}/losses.png")
plt.close(fig)

In [None]:
# testing network output: compare the best checkpoint's predictions with the ground truth
# on random evaluation samples (actual, predicted, overlaid contours, clipped squared error)
for titl, best_model_path in zip(["TOT"], ["mg_planet_tot.pth"]):
    model = SXRNet()
    # map_location: the checkpoint may have been saved from a cuda/mps model
    model.load_state_dict(torch.load(f"{SAVE_DIR}/{best_model_path}", map_location=torch.device("cpu")))
    model.eval()
    ds = SXRDataset(EVAL_DS_PATH)
    # ds = SXRDataset(TRAIN_DS_PATH)
    os.makedirs(f"{SAVE_DIR}/imgs", exist_ok=True)
    N_PLOTS = 2 if HAS_SCREEN else 50
    mse_levels1 = np.linspace(0, 0.5, 13, endpoint=True)   # squared-error levels, coarse
    mse_levels2 = np.linspace(0, 0.05, 13, endpoint=True)  # squared-error levels, fine
    for i in np.random.randint(0, len(ds), N_PLOTS):
        sxr, em_ds, rr, zz = (t.cpu() for t in ds[i])
        with torch.no_grad():
            em_pred = model((sxr.view(1, -1), rr.view(1, 1, 64, 64), zz.view(1, 1, 64, 64)))
        em_pred = em_pred.numpy().reshape(64, 64)
        em_ds = em_ds.numpy().reshape(64, 64)
        rr, zz = rr.numpy().reshape(64, 64), zz.numpy().reshape(64, 64)
        # shared color levels for actual and predicted maps; contour levels must be increasing,
        # so widen a degenerate (constant) range
        bmin = min(em_ds.min(), em_pred.min())
        bmax = max(em_ds.max(), em_pred.max())
        if bmax <= bmin:
            bmax = bmin + 1e-6
        blevels = np.linspace(bmin, bmax, 13, endpoint=True)
        em_mse = (em_ds - em_pred)**2  # per-pixel squared error

        fig, axs = plt.subplots(1, 5, figsize=(15, 4.5))
        im0 = axs[0].contourf(rr, zz, em_ds, blevels, cmap="inferno")
        axs[0].set_title("Actual")
        axs[0].set_ylabel("em")
        fig.colorbar(im0, ax=axs[0])
        im1 = axs[1].contourf(rr, zz, em_pred, blevels, cmap="inferno")
        axs[1].set_title("Predicted")
        fig.colorbar(im1, ax=axs[1])
        im2 = axs[2].contour(rr, zz, em_ds, blevels, linestyles='dashed', cmap="inferno")  # dashed: actual
        axs[2].contour(rr, zz, em_pred, blevels, cmap="inferno")  # solid: predicted
        axs[2].set_title("Contours")
        fig.colorbar(im2, ax=axs[2])
        im3 = axs[3].contourf(rr, zz, np.clip(em_mse, 0, 0.5), mse_levels1, cmap="inferno")
        axs[3].set_title("MSE 0.5")
        fig.colorbar(im3, ax=axs[3])
        im4 = axs[4].contourf(rr, zz, np.clip(em_mse, 0.00001, 0.04999), mse_levels2, cmap="inferno")
        axs[4].set_title("MSE 0.05")
        fig.colorbar(im4, ax=axs[4])
        for a in axs:
            a.grid(False)
            a.set_xticks([])
            a.set_yticks([])
            a.set_aspect("equal")
        fig.suptitle(f"SXRNet: {titl} {i}")
        fig.tight_layout()
        if HAS_SCREEN:
            plt.show()
        else:
            fig.savefig(f"{SAVE_DIR}/imgs/{titl}_{i}.png")
        plt.close(fig)

In [None]:
# test inference speed: single-sample forward-pass latency on CPU and on `device`
from time import perf_counter

def _sync(dev):
    """Block until queued GPU work on ``dev`` has finished, so timings include the compute."""
    if dev.type == "cuda":
        torch.cuda.synchronize()
    elif dev.type == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()

def _time_inference(net, samples, dev, n_warmup=3):
    """Return an array of per-sample forward-pass wall times (s) of ``net`` on ``dev``.

    ``samples`` is a list of ``(sxr, rr, zz)`` model inputs. They are moved to ``dev``
    before any clock starts, so only the forward pass itself is timed.
    """
    net.to(dev)
    inputs = [tuple(t.to(dev) for t in s) for s in samples]
    times = []
    with torch.no_grad():
        for x in inputs[:n_warmup]:  # warm-up: first calls pay one-off setup costs
            net(x)
        _sync(dev)
        for x in inputs:
            start_t = perf_counter()
            net(x)
            _sync(dev)
            times.append(perf_counter() - start_t)
    return np.array(times)

model = SXRNet()
model.load_state_dict(torch.load(f"{SAVE_DIR}/mg_planet_tot.pth", map_location=torch.device("cpu")))
model.eval()
ds = SXRDataset(EVAL_DS_PATH)
n_samples = 100
# pick n_samples distinct random sample indices out of the len(ds) available
random_idxs = np.random.choice(len(ds), size=min(n_samples, len(ds)), replace=False)
samples = []
for i in random_idxs:
    sxr, _, rr, zz = ds[i]
    samples.append((sxr.view(1, -1), rr.view(1, 1, 64, 64), zz.view(1, 1, 64, 64)))
cpu_times = _time_inference(model, samples, torch.device("cpu"))
dev_times = _time_inference(model, samples, device)  # leaves the model on `device`
print(f"cpu: inference time: {cpu_times.mean():.5f}s, std: {cpu_times.std():.5f}")
print(f"dev: inference time: {dev_times.mean():.5f}s, std: {dev_times.std():.5f}")

In [None]:
# parameters of the square (R, Z) computational grid
RES = 64 # [#] resolution of the grid in pixels (square grid)
L = 1.0 # [m] length of the grid in the r/x direction (square grid)
R0 = 1.5 # [m] grid start in the r/x direction
Z0 = -0.5 # [m] grid start in the z/y direction
R1, Z1 = R0+L, Z0+L # [m] grid ends in the r/x and z/y direction
RM, ZM = 0.5*(R0+R1), 0.5*(Z0+Z1) # [m] grid center in the x/r and z direction
# calculated constants
R = np.linspace(R0, R1, RES) # RES points, both endpoints R0 and R1 included
Z = np.linspace(Z0, Z1, RES)
assert np.isclose(R1-R0, Z1-Z0), "grid must be square"
δ = L/(RES-1) # [m] grid spacing: RES points including both ends span RES-1 intervals
assert np.isclose(δ, R[1]-R[0]) and np.isclose(δ, Z[1]-Z[0]), "grid spacing mismatch"
RR, ZZ = np.meshgrid(R, Z) # (RES, RES) each: RR varies along axis 1, ZZ along axis 0
RZ = np.stack((RR, ZZ), axis=-1) # (RES, RES, 2) array of (r, z) pairs