## Training
Prepare the dataset with the `prepare_dataset` notebook before running this one.

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
import scipy.io as sio
from time import time
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from utils import calc_gso_batch

# Pick the best available backend instead of hard-coding one, so the notebook
# also runs on machines without Apple-silicon (MPS) support.
if torch.backends.mps.is_available():
    device = torch.device("mps") # apple silicon
elif torch.cuda.is_available():
    device = torch.device("cuda") # nvidia
else:
    device = torch.device("cpu") # cpu fallback
print(f'device: {device}')

def to_tensor(x, device=torch.device("cpu")):
    """Build a float32 tensor from array-like ``x`` on ``device`` (CPU by default)."""
    as_float32 = torch.tensor(x, device=device, dtype=torch.float32)
    return as_float32

# Best MSE loss: 0.087

In [2]:
# --- checkpoint location and optimisation settings ---
MODEL_SAVE_PATH = "mg_data/mg_planet.pth"
EPOCHS = 30
BATCH_SIZE = 256
LEARNING_RATE = 3e-3
# --- switches selecting which input groups are fed to the network ---
USE_CURRENTS = True
USE_PROFILES = False
USE_MAGNETIC = True
# Width of the input vector: 14 currents, 202 profile values (2 x 101) and
# 187 magnetic measurements, counting only the enabled groups.
INPUT_SIZE = sum(
    width
    for enabled, width in (
        (USE_CURRENTS, 14),
        (USE_PROFILES, 202),
        (USE_MAGNETIC, 187),
    )
    if enabled
)
# --- dataset files ---
TRAIN_DS_PATH = "data/train_ds.mat"
EVAL_DS_PATH = "data/eval_ds.mat"

- mean current: -10183.76, std current: 34209.11
- mean magnetic: -0.20, std magnetic: 0.58
- mean f_profile: 33.13, std f_profile: 0.28
- mean p_profile: 9654.42, std p_profile: 8788.29

In [3]:
class PlaNetDataset(Dataset):
    """In-memory dataset loaded from a MATLAB ``.mat`` file.

    Indexing yields the tuple ``(inputs, psi, rr, zz)``: the normalized input
    feature vector, the target flux map and the radial / vertical pixel
    coordinates. The last three are views shaped ``(1, 64, 64)`` per sample.

    Which feature groups end up in ``inputs`` is decided by the module-level
    ``USE_CURRENTS``, ``USE_MAGNETIC`` and ``USE_PROFILES`` switches.
    """

    def __init__(self, ds_mat_path):
        """Load ``ds_mat_path`` and precompute the normalized input matrix."""
        mat = sio.loadmat(ds_mat_path)
        grid_shape = (-1, 1, 64, 64)  # (sample, channel, 64, 64)

        # target: magnetic flux map
        self.psi = to_tensor(mat["psi"]).view(*grid_shape)
        # pixel coordinates (currently only used downstream for the physics loss / plotting)
        self.rr = to_tensor(mat["rr"]).view(*grid_shape)  # radial position of pixels
        self.zz = to_tensor(mat["zz"]).view(*grid_shape)  # vertical position of pixels

        # raw (un-normalized) inputs, kept as loaded from the file
        self.currs = mat["currs"]            # input currents (n, 14)
        self.magnetic = mat["magnetic"]      # input magnetic measurements (n, 187)
        self.f_profile = mat["f_profiles"]   # input profiles (n, 101)
        self.p_profile = mat["p_profiles"]   # input profiles (n, 101)

        # Standardize each enabled group with hard-coded offsets / scales
        # (rounded dataset mean and std) and stack them column-wise.
        features = []
        if USE_CURRENTS:
            currs = to_tensor(self.currs)
            features.append((currs + 10183) / 34209)  # (n, 14)
        if USE_MAGNETIC:
            magnetic = to_tensor(self.magnetic)
            features.append((magnetic + 0.2) / 0.58)  # (n, 187)
        if USE_PROFILES:
            f_norm = (to_tensor(self.f_profile) - 33.13) / 0.28
            p_norm = (to_tensor(self.p_profile) - 9654) / 8788
            features.append(torch.cat((f_norm, p_norm), 1))  # (n, 202)
        self.inputs = torch.cat(features, 1)  # (n, sum of enabled widths)

    def __len__(self):
        """Number of samples (leading dimension of ``psi``)."""
        return self.psi.shape[0]

    def __getitem__(self, idx):
        """Return ``(inputs, psi, rr, zz)`` for sample ``idx``."""
        sample = (self.inputs[idx], self.psi[idx], self.rr[idx], self.zz[idx])
        return sample

In [None]:
# Smoke test: load the evaluation set and report its size and per-sample shapes
ds = PlaNetDataset(EVAL_DS_PATH)
n_samples = len(ds)
print(f"Dataset length: {n_samples}")
# element 0 of a sample is the input vector, element 1 the target flux map
print(f"Input shape: {ds[0][0].shape}")
print(f"Output shape: {ds[0][1].shape}")

In [5]:
# class PlaNet(nn.Module): # simple fully connected neural network > weak > loss:33 > converges to a constant
#     def __init__(self):
#         super(PlaNet, self).__init__()
#         self.n = 8
#         self.fc1 = nn.Linear(INPUT_SIZE, self.n)
#         self.fc2 = nn.Linear(self.n, self.n)
#         self.fc3 = nn.Linear(self.n, 64*64)
#     def forward(self, x):
#         x = torch.relu(self.fc1(x))
#         x = torch.relu(self.fc2(x))
#         x = self.fc3(x)
#         x = x.view(-1, 64, 64)
#         return x

In [6]:
# class PlaNet(nn.Module): # transpose convolutional neural network > stronger, fast, but artifacts
#     def __init__(self):
#         super(PlaNet, self).__init__()
#         self.n = n = 4
#         self.fc = nn.Sequential(
#             nn.Linear(INPUT_SIZE, 8*n),
#             nn.ReLU(),
#             nn.Linear(8*n, 16*n),
#             nn.ReLU(),
#         )
#         self.unconv = nn.Sequential(
#             nn.ConvTranspose2d(16*n, 8*n, kernel_size=3),
#             nn.ReLU(),
#             nn.BatchNorm2d(8*n), # batch normalization
#             nn.ConvTranspose2d(8*n, 4*n, kernel_size=3, stride=2),
#             nn.ReLU(),
#             nn.ConvTranspose2d(4*n, 2*n, kernel_size=3, stride=2),
#             nn.ReLU(),
#             nn.ConvTranspose2d(2*n, n, kernel_size=3, stride=2),
#             nn.ReLU(),
#             nn.ConvTranspose2d(n, 1, kernel_size=4, stride=2),
#         )
#     def forward(self, x):
#         x = self.fc(x)
#         x = x.view(-1, 16*self.n, 1, 1)
#         x = self.unconv(x)
#         x = x.view(-1, 64, 64)
#         return x

In [7]:
class PlaNet(nn.Module):
    """Upsample + convolution decoder (slower than transposed convs, but smoother).

    A small fully connected stem maps the ``INPUT_SIZE`` feature vector to
    ``16 * n`` channels on a 1x1 grid. Five upsample/conv stages (x4, then
    x2 four times) grow it to a single-channel 64x64 map, so ``forward``
    returns a tensor of shape ``(batch, 1, 64, 64)``.
    """

    def __init__(self):
        super().__init__()
        self.interp = 'bilinear' # 'nearest' or 'bilinear'
        self.n = n = 2  # base channel multiplier

        # fully connected stem: INPUT_SIZE -> 8n -> 16n
        self.fc = nn.Sequential(
            nn.Linear(INPUT_SIZE, 8*n),
            nn.ReLU(),
            nn.Linear(8*n, 16*n),
            nn.ReLU(),
        )

        # Decoder layers are created in the same order as a hand-written
        # Sequential, so parameter names and initialization are unchanged.
        stages = [
            nn.Upsample(scale_factor=4, mode=self.interp),
            nn.Conv2d(16*n, 8*n, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(8*n), # batch normalization
        ]
        for c_in, c_out in ((8*n, 4*n), (4*n, 2*n), (2*n, n)):
            stages.extend([
                nn.Upsample(scale_factor=2, mode=self.interp),
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.ReLU(),
            ])
        # final stage: no activation, single output channel
        stages.extend([
            nn.Upsample(scale_factor=2, mode=self.interp),
            nn.Conv2d(n, 1, kernel_size=3, padding=1),
        ])
        self.unconv = nn.Sequential(*stages)

    def forward(self, x):
        """Map a ``(batch, INPUT_SIZE)`` tensor to a ``(batch, 1, 64, 64)`` map."""
        hidden = self.fc(x)
        # treat the feature vector as a 1x1 image with 16n channels
        seed = hidden.view(-1, 16*self.n, 1, 1)
        # the channel dimension is kept on purpose (no squeeze to (batch, 64, 64))
        return self.unconv(seed)

In [None]:
# Smoke test: push a single random input vector through a freshly built model
model = PlaNet()
x = torch.randn((1, INPUT_SIZE))  # batch of one
print(f"Input shape: {x.shape}")
y = model(x)
print(f"Output shape: {y.shape}")

## Training

In [None]:
# Train PlaNet with a combined loss (MSE on psi + MSE on the Grad-Shafranov
# operator output) and keep the checkpoint with the lowest validation loss.
train_ds, val_ds = PlaNetDataset(TRAIN_DS_PATH), PlaNetDataset(EVAL_DS_PATH) # initialize datasets
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True) # initialize DataLoader
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
model = PlaNet()  # instantiate model
model.to(device) # move model to device
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.MSELoss() # Mean Squared Error Loss
best_loss = float('inf') # initialize best loss


def _batch_losses(model, loss_fn, batch):
    """Return ``(total, mse, gso)`` loss tensors for one DataLoader batch.

    ``batch`` is the 4-item collation of ``PlaNetDataset`` samples:
    ``(inputs, psi, rr, zz)``. Training and validation both go through this
    helper so the two phases always compute the loss the same way.
    """
    inputs, psi, rr, zz = (t.to(device) for t in batch) # move to device
    psi_pred = model(inputs)
    gso = calc_gso_batch(psi, rr, zz, device=device)
    gso_pred = calc_gso_batch(psi_pred, rr, zz, device=device)
    mse_loss = loss_fn(psi_pred, psi) # mean squared error loss on psi
    gso_loss = loss_fn(gso_pred, gso) # PINN loss on grad shafranov
    return mse_loss + gso_loss, mse_loss, gso_loss


for epoch in range(EPOCHS):
    epoch_time = time()
    trainloss, evalloss = [], []

    # --- training pass ---
    model.train()
    for batch in train_dl:
        optimizer.zero_grad()
        loss, mse_loss, gso_loss = _batch_losses(model, loss_fn, batch)
        loss.backward()
        optimizer.step()
        trainloss.append((loss.item(), mse_loss.item(), gso_loss.item()))

    # --- validation pass (no gradients, BatchNorm in eval mode) ---
    model.eval()
    with torch.no_grad():
        # val_dl yields the same 4-tensor batches as train_dl
        for batch in val_dl:
            loss, mse_loss, gso_loss = _batch_losses(model, loss_fn, batch)
            evalloss.append((loss.item(), mse_loss.item(), gso_loss.item()))

    # per-epoch averages over batches, column by column (total, mse, gso)
    ttot_loss, tmse_loss, tgso_loss = (sum(col)/len(col) for col in zip(*trainloss))
    etot_loss, emse_loss, egso_loss = (sum(col)/len(col) for col in zip(*evalloss))
    print(f"Ep {epoch+1}/{EPOCHS}: Train Loss:{ttot_loss:.4f}, mse {tmse_loss:.4f}, gso {tgso_loss:.4f} ||" +
          f"Eval Loss:{etot_loss:.4f}, mse {emse_loss:.4f}, gso {egso_loss:.4f}, t:{time()-epoch_time:.2f}s,", end=" ")
    # checkpoint only when the validation loss improves
    if etot_loss < best_loss:
        best_loss = etot_loss
        torch.save(model.state_dict(), MODEL_SAVE_PATH)
        print("new best")
    else:
        print()

In [None]:
# Reload the best checkpoint and compare prediction vs. ground truth on
# 10 randomly drawn samples.
model = PlaNet()
# Inference below runs on the CPU, so map the checkpoint there even if it was
# saved from an accelerator.
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location="cpu"))
model.eval()
ds = PlaNetDataset(EVAL_DS_PATH)
# ds = PlaNetDataset(TRAIN_DS_PATH)
for i in np.random.randint(0, len(ds), 10):
    # the dataset returns (inputs, psi, rr, zz) for each sample
    inputs, psi_ds, rr, zz = ds[int(i)]
    with torch.no_grad():
        psi_pred = model(inputs.unsqueeze(0))
    psi_pred = psi_pred.numpy().reshape(64, 64)
    psi_ds = psi_ds.numpy().reshape(64, 64)
    # this sample's 2-D coordinate grids (contour needs 64x64 arrays)
    rr, zz = rr.numpy().reshape(64, 64), zz.numpy().reshape(64, 64)
    ext = [float(rr.min()), float(rr.max()), float(zz.min()), float(zz.max())]

    bmin, bmax = np.min([psi_ds, psi_pred]), np.max([psi_ds, psi_pred])
    span = abs(bmax - bmin) or 1.0  # avoid 0/0 when both maps are constant
    err = np.abs(psi_ds - psi_pred)*100/span  # error in % of the value range
    # err = np.abs(psi_ds - psi_pred)*100/abs((psi_ds + psi_pred)/2)
    err_mse = (psi_ds - psi_pred)**2

    fig, axs = plt.subplots(1, 5, figsize=(15, 5))

    # image panels: (title, data, vmin, vmax); the first two share one color scale
    panels = [
        ("Actual", psi_ds, bmin, bmax),
        ("Predicted", psi_pred, bmin, bmax),
        ("Error", err, 0, 5),
        ("MSE", err_mse, 0, 0.5),
    ]
    for ax, (title, image, vmin, vmax) in zip(axs, panels):
        im = ax.imshow(image, extent=ext, vmin=vmin, vmax=vmax)
        ax.set_title(title)
        ax.set_aspect('equal')
        fig.colorbar(im, ax=ax)

    # overlay: ground truth dashed, prediction solid
    c0 = axs[4].contour(rr, zz, psi_ds, levels=20, cmap='viridis', linestyles='dashed')
    c1 = axs[4].contour(rr, zz, psi_pred, levels=20, cmap='viridis')
    axs[4].set_title("Contours")
    axs[4].set_aspect('equal')

    plt.tight_layout()
    plt.show()