<a href="https://colab.research.google.com/github/linhtrinh213/Interpretable_VAE_TF_regulons/blob/main/Stochastic_Weight_Averaging.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Stochastic Weight Averaging
- We implement SWA on our variational autoencoder (VAE) by training with the Adam optimizer while an `AveragedModel` wrapper keeps a running average of the weights. After training, the averaged (SWA) weights are used as the final model.

- SWA works by averaging model weights collected during training with stochastic gradient descent (SGD). On its own, SGD typically converges near the edge of a low-loss region, and such edge solutions often generalize poorly to test data. Averaging instead tends to produce solutions near the center of wide, flat regions of the loss landscape, which tend to generalize better. During the final 25% of training, the learning rate is increased to encourage exploration of the low-loss region while the weights are averaged. More on the method: https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
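
The averaging step itself is simple. The following is a minimal sketch of an equal-weight running average over weight snapshots, written in plain Python so it runs without PyTorch. The function name `update_running_average` and the toy checkpoints are assumptions made for illustration; in the notebook this bookkeeping is delegated to `AveragedModel`.

```python
def update_running_average(avg_weights, new_weights, n_averaged):
    """Fold one more set of weights into an equal-weight running average."""
    return [
        avg + (new - avg) / (n_averaged + 1)
        for avg, new in zip(avg_weights, new_weights)
    ]

# Toy "checkpoints": the weight vector at the end of three epochs (made-up values)
checkpoints = [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]

# Start from the first checkpoint, then fold in the remaining ones
swa_weights = list(checkpoints[0])
for n, weights in enumerate(checkpoints[1:], start=1):
    swa_weights = update_running_average(swa_weights, weights, n)

print(swa_weights)  # element-wise mean of the three checkpoints
```

Because each snapshot enters with the same weight, the result equals the element-wise mean of all collected snapshots, without having to store them.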

### 1: Loading libraries

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
#### Setup ####
# install and import required packages
!pip install scanpy
#!pip install decoupler
#!pip install omnipath

import torch; torch.manual_seed(100)
import torch.nn as nn
import torch.utils
import torch.distributions
import torchvision
from torchvision import datasets, transforms
import math
import numpy as np
np.random.seed(100)
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
import scanpy as sc
from collections import OrderedDict
from collections import Counter
import pandas as pd

# select the right device, depending on whether your Colab runtime has a GPU
### IMPORTANT: we recommend changing your runtime to GPU, otherwise training takes much longer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


Collecting scanpy
  Downloading scanpy-1.11.1-py3-none-any.whl.metadata (9.9 kB)
Collecting anndata>=0.8 (from scanpy)
  Downloading anndata-0.11.4-py3-none-any.whl.metadata (9.3 kB)
Collecting legacy-api-wrap>=1.4 (from scanpy)
  Downloading legacy_api_wrap-1.4.1-py3-none-any.whl.metadata (2.1 kB)
Collecting scikit-learn<1.6.0,>=1.1 (from scanpy)
  Downloading scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting session-info2 (from scanpy)
  Downloading session_info2-0.1.2-py3-none-any.whl.metadata (2.5 kB)
Collecting array-api-compat!=1.5,>1.4 (from anndata>=0.8->scanpy)
  Downloading array_api_compat-1.12.0-py3-none-any.whl.metadata (2.5 kB)
Downloading scanpy-1.11.1-py3-none-any.whl (2.1 MB)
Downloading anndata-0.11.4-py3-none-any.whl (144 kB)

### 2: Load data

In [3]:
# Load data and use Scanpy to convert it into AnnData
PBMC_train = sc.read_h5ad("/content/drive/MyDrive/WORK/Turing Project/Interpretable_VAE/data/PBMC_train.h5ad")
regulons = pd.read_csv('/content/drive/MyDrive/WORK/Turing Project/Interpretable_VAE/data/regulons.csv')

In [4]:
# divide data into control and stimulated

# Subset the control cells
PBMC_control = PBMC_train[PBMC_train.obs["condition"] == "control"].copy()

# Subset the stimulated cells
PBMC_stimulated = PBMC_train[PBMC_train.obs["condition"] == "stimulated"].copy()


### 3: Define model architecture

In [5]:
# Define Encoder:
class Encoder(nn.Module):
    def __init__(self, latent_dims, input_dims, dropout, z_dropout): #dropout between the dense layers, z_dropout define the dropout rates between the encoder/latent space
        super(Encoder, self).__init__() #run the initialize code from nn.Module -> this class behaves like a Pytorch model
        self.encoder = nn.Sequential(
                                     nn.Linear(input_dims, 800),
                                     nn.ReLU(),
                                     nn.Dropout(p = dropout),
                                     nn.Linear(800, 800),
                                     nn.ReLU(),
                                     nn.Dropout(p = dropout))  #two layer, fully connected encoder with dropout

        # outputs mean vector u
        self.mu = nn.Sequential(nn.Linear(800, latent_dims), # the 800 neurons in the second layers -> latent space
                                nn.Dropout(p = z_dropout))
        # outputs the log standard deviation (exponentiated in forward)
        self.sigma = nn.Sequential(nn.Linear(800, latent_dims),
                                   nn.Dropout(p = z_dropout))

        self.N = torch.distributions.Normal(0, 1)  # define Gaussian distribution for each input
        self.N.loc = self.N.loc.to(device) # move to the right device
        self.N.scale = self.N.scale.to(device)
        self.kl = 0 # place holder for storing KL divergence (regularization term)
        # KL measures how far the learned Gaussian is from the standard normal (0,1) -> this is a regularization term in VAE
    def forward(self, x):
        x = self.encoder(x) # pass the data to the encoder
        mu =  self.mu(x) # predict mean vector
        sigma = torch.exp(self.sigma(x)) # the layer predicts log(sigma); exp keeps the standard deviation positive
        z = mu + sigma*self.N.sample(mu.shape)  # Sample z using reparameterization trick

        self.kl = (0.5*sigma**2 + 0.5*mu**2 - torch.log(sigma) - 1/2).sum() #calculation of kullback-leibler divergence

        return z # output is the sampled latent vector
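
The KL term and the reparameterization step in `Encoder.forward` can be checked by hand for a single latent dimension. This is a small sketch in plain Python rather than the notebook's tensor code; the helper names `kl_to_standard_normal` and `reparameterize` are assumptions introduced here.

```python
import math

def kl_to_standard_normal(mu, sigma):
    """KL( N(mu, sigma^2) || N(0, 1) ) for one latent dimension."""
    return 0.5 * sigma**2 + 0.5 * mu**2 - math.log(sigma) - 0.5

def reparameterize(mu, sigma, eps):
    """Shift and scale standard-normal noise eps so z ~ N(mu, sigma^2)."""
    return mu + sigma * eps

# The KL term vanishes when the posterior already equals the prior
print(kl_to_standard_normal(0.0, 1.0))
# Moving the mean away from 0 is penalized quadratically
print(kl_to_standard_normal(1.0, 1.0))
# With eps = 0 the sample is just the mean
print(reparameterize(0.3, 2.0, 0.0))
```

The encoder sums this per-dimension quantity over all latent dimensions and all cells in the batch, which is why the KL value scales with the batch size.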


In [6]:
import pandas as pd
import numpy as np

def create_mask(adata, regulons, add_nodes:int=10, sep = "\t"):
    """
    Initialize mask M that specifies which latent nodes connect to which decoder nodes.
    Args:
        adata (Anndata): Scanpy single-cell object, we will store the computed mask and the names of the biological processes there
        regulons: which TFs affected which genes
        add_nodes (int): Additional latent nodes for capturing additional variance
    Return:
        adata (Anndata): Scanpy single-cell object that now stores the computed mask and the names of biological processes (in the .uns["_vega"] attribute)
        mask (array): mask M that specifies whether a gene is included in the gene set of a pathway (value one) or not (value zero)
    """

    # Create the mask
    # 1. Get unique genes (targets) and TFs (sources)
    genes = regulons['target'].unique()
    tfs = regulons['source'].unique()

    # get their names and the corresponding sub‐mask
    selected_tfs  = ["STAT1", "STAT2", "STAT3","STAT4","STAT5A","STAT5B","STAT6","IRF1","IRF2","IRF3","IRF4","IRF5","IRF6","IRF7","IRF8","IRF9","NFKB","AP1","MYC","TP53"]

    # 2. Initialize matrix M with zeros
    M = pd.DataFrame(0, index=genes, columns=tfs)

    # 3. Set M[i,j] = 1 where the gene i is affected by TF j
    for _, row in regulons.iterrows(): # for each row in regulons
        M.loc[row['target'], row['source']] = 1 #the corresponding genes, TF box = 1

    M = M.loc[:, selected_tfs]  # if M is a pandas DataFrame

    # Add unannotated nodes
    vec = np.ones((M.shape[0], add_nodes))
    M = np.hstack((M, vec))

    adata.uns['_vega'] = dict() #create attribute "_vega" to store the mask and pathway information
    adata.uns['_vega']['mask'] = M
    # store the names of the mask columns: the selected TFs followed by the unannotated nodes
    adata.uns['_vega']['TFs'] = list(selected_tfs) + ['UNANNOTATED_'+str(k) for k in range(add_nodes)]

    return adata, M
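
To make the mask layout concrete, here is a toy version in plain Python: rows are genes, columns are TFs, and fully connected unannotated columns are appended at the end. The helper `build_mask` and the gene names `GENE_A` to `GENE_C` are hypothetical and only illustrate the idea; the notebook builds the real mask with pandas in `create_mask`.

```python
def build_mask(pairs, genes, tfs, add_nodes=1):
    """Binary gene x TF mask with add_nodes fully connected columns appended."""
    tf_col = {tf: j for j, tf in enumerate(tfs)}
    gene_row = {gene: i for i, gene in enumerate(genes)}
    # start with zeros for the TF columns and ones for the unannotated columns
    mask = [[0] * len(tfs) + [1] * add_nodes for _ in genes]
    for tf, gene in pairs:
        if tf in tf_col and gene in gene_row:
            mask[gene_row[gene]][tf_col[tf]] = 1
    return mask

# Hypothetical (source TF, target gene) pairs
toy_pairs = [("STAT1", "GENE_A"), ("IRF1", "GENE_A"), ("STAT1", "GENE_B")]
toy_genes = ["GENE_A", "GENE_B", "GENE_C"]
toy_tfs = ["STAT1", "IRF1"]

toy_mask = build_mask(toy_pairs, toy_genes, toy_tfs, add_nodes=1)
for gene, row in zip(toy_genes, toy_mask):
    print(gene, row)

# Genes whose only connection is the unannotated node carry no TF information
kept_genes = [g for g, row in zip(toy_genes, toy_mask) if sum(row) >= 2]
print(kept_genes)
```

`GENE_C` is targeted by no TF, so its row contains only the unannotated entry and it is dropped by the "at least 2 non-zero entries" rule used later in the notebook.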

In [7]:
# apply the create_mask function
PBMC_control, mask_ctr = create_mask(PBMC_control,regulons , add_nodes=1)
PBMC_stimulated, mask_sti  = create_mask(PBMC_stimulated,regulons , add_nodes=1)

In [8]:
#---# filter the genes in mask
# define names to filter genes for PBMC in the next chunk
mask_ctr_df = pd.DataFrame(mask_ctr, index= regulons.target.unique())
# list of genes in the regulons list and in our pbmc data
genes_mask = np.array(regulons.target.unique()) # genes in regulons list
pbmc_genes = np.array(PBMC_train.var_names) # genes in OUR data
# Create boolean mask of which genes are in PBMC_train
keep = np.isin(genes_mask, pbmc_genes)
# Apply the filter
filtered_mask_ctr_df = mask_ctr_df.loc[keep, :]

# Drop genes that are connected to no TF. The last (unannotated) node is fully connected,
# so every row has at least 1 non-zero entry; keep rows with at least 2.
non_zero_count = np.count_nonzero(filtered_mask_ctr_df, axis=1)
filtered_mask_ctr_df = filtered_mask_ctr_df[non_zero_count >= 2]

# mask_sti is identical to mask_ctr (both are built from the same regulons), so reuse the filtered mask
filtered_mask_sti_df = filtered_mask_ctr_df.copy()

In [9]:
# filter the genes in PBMC data
PBMC_control_filtered = PBMC_control[:, filtered_mask_ctr_df.index].copy()
PBMC_stimulated_filtered = PBMC_stimulated[:, filtered_mask_sti_df.index].copy()


In [10]:
# convert pandas back to numpy array to use for downstream steps
filtered_mask_ctr = filtered_mask_ctr_df.to_numpy()
filtered_mask_sti = filtered_mask_sti_df.to_numpy()


In [11]:
# define VEGA's decoder
class DecoderVEGA(nn.Module):
  """
  Define VEGA's decoder (sparse, one-layer, linear, positive)
  """
  def __init__(self,mask):
        super(DecoderVEGA, self).__init__()

        self.sparse_layer = nn.Sequential(SparseLayer(mask)) # we define the architecture of the decoder below with the class "SparseLayer"
        # This decoder only has 1 layer (Sparse)!!!

  def forward(self, x):
    z = self.sparse_layer(x.to(device))
    return(z)

# define a class SparseLayer, that specifies the decoder architecture (sparse connections based on the mask)
class SparseLayer(nn.Module):
  def __init__(self, mask):
        """
        Extended torch.nn module which mask connection
        """
        super(SparseLayer, self).__init__()

        self.mask = nn.Parameter(torch.tensor(mask, dtype=torch.float).t(), requires_grad=False)
        self.weight = nn.Parameter(torch.Tensor(mask.shape[1], mask.shape[0]))
        self.bias = nn.Parameter(torch.Tensor(mask.shape[1]))
        self.reset_parameters()

        # mask weight
        self.weight.data = self.weight.data * self.mask

  def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        self.bias.data.uniform_(-stdv, stdv)

  def forward(self, input):
        # See the autograd section for explanation of what happens here
        return SparseLayerFunction.apply(input, self.weight, self.bias, self.mask)
        # OUTPUT of the decoder



# defines a custom forward and backward pass
class SparseLayerFunction(torch.autograd.Function):
    """
    We define our own autograd function which masks its weights by 'mask'.
    For more details, see https://pytorch.org/docs/stable/notes/extending.html
    """

    # Note that both forward and backward are @staticmethods
    @staticmethod
    def forward(ctx, input, weight, bias, mask):
        # enforce the forward connection between latent and next layer to be sparse
        weight = weight * mask # change weight to 0 where mask == 0
        #calculate the output
        output = input.mm(weight.t()) # output = input × weight.T  (torch.mm : matrix multiplication)
        # IMPORTANT: shapes
        #   input (the latent vector): (batch_size, latent_dim)
        #   weight and mask:           (input_dim, latent_dim), where input_dim is the number of genes
        #   output:                    (batch_size, input_dim), the same shape as the original expression input
        output += bias.unsqueeze(0).expand_as(output) # Add bias to all values in output
        ctx.save_for_backward(input, weight, bias, mask)
        return output

    @staticmethod
    def backward(ctx, grad_output): # define the gradient formula
        # compute gradient for backpropagation
        input, weight, bias, mask = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = grad_mask = None

        # These needs_input_grad checks are optional and only to improve efficiency
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
            # grad_input: how the loss changes with respect to the input
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input) # how the loss changes with respect to weight
            grad_weight = grad_weight * mask # change grad_weight to 0 where mask == 0  (enforce the mask even in backward pass)
        if ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return grad_input, grad_weight, grad_bias, grad_mask
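
The effect of the mask on the forward pass and on the weight gradient can be traced on a tiny example. The sketch below re-implements the two formulas from `SparseLayerFunction` with nested lists instead of tensors; the helper names and the toy numbers are assumptions for illustration only.

```python
def transpose(m):
    return [list(col) for col in zip(*m)]

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def hadamard(a, b):
    return [[x * y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

def masked_forward(inp, weight, bias, mask):
    """output = input @ (weight * mask).T + bias"""
    out = matmul(inp, transpose(hadamard(weight, mask)))
    return [[v + b for v, b in zip(row, bias)] for row in out]

def masked_weight_grad(inp, grad_output, mask):
    """grad_weight = (grad_output.T @ input) * mask"""
    return hadamard(matmul(transpose(grad_output), inp), mask)

# 2 latent nodes -> 3 genes, batch of 1; mask and weight have shape (genes, latent)
mask = [[1, 0], [0, 1], [1, 1]]
weight = [[0.5, 9.0], [9.0, 2.0], [1.0, 1.0]]  # the 9.0 entries sit where mask == 0
bias = [0.0, 0.0, 0.1]
z = [[2.0, 3.0]]

out = masked_forward(z, weight, bias, mask)
grad_w = masked_weight_grad(z, [[1.0, 1.0, 1.0]], mask)
print(out)     # the masked 9.0 weights do not contribute
print(grad_w)  # gradient entries are zero wherever mask == 0
```

Since the gradient is zero at masked positions, an optimizer step driven purely by this gradient leaves those connections untouched, and the forward pass re-applies the mask in any case.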


### 4: Training
SWA is applied during training.

In [12]:
# define class that combine encoder and decoder
class VEGA(nn.Module):
    def __init__(self, latent_dims, input_dims, mask, dropout = 0.3, z_dropout = 0.3):
        super(VEGA, self).__init__()
        self.encoder = Encoder(latent_dims, input_dims, dropout, z_dropout) # we use the same encoder as before (two-layer, fully connected, non-linear)
        self.decoder = DecoderVEGA(mask)

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

In [13]:
#training loop
def trainVEGA(vae, data, epochs=30, beta = 0.0001, learning_rate = 0.01):
    opt = torch.optim.Adam(vae.parameters(), lr = learning_rate, weight_decay = 5e-4)
    vae.train() #train mode
    losses = []
    klds = []
    mses = []

    for epoch in range(epochs):
        loss_e = 0
        kld_e = 0
        mse_e = 0

        for x in data:
            x = x.to(device)
            opt.zero_grad()
            x_hat = vae(x)
            mse = ((x - x_hat)**2).sum()
            kld = beta* vae.encoder.kl
            loss = mse +  kld # loss calculation
            loss.backward()
            opt.step()
            loss_e += loss.to('cpu').detach().numpy()
            kld_e += kld.to('cpu').detach().numpy()
            mse_e += mse.to('cpu').detach().numpy()

        losses.append(loss_e/(len(data)*128))
        klds.append(kld_e/(len(data)*128))
        mses.append(mse_e/(len(data)*128))

        print("epoch: ", epoch, " loss: ", loss_e/(len(data)*128))

    return vae, losses, klds, mses
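
A note on the normalization used for the logged values: `len(data) * 128` counts batches times batch size, which overestimates the number of cells whenever the last batch is not full. The sketch below compares that approximation with dividing by the true number of cells (what `len(data.dataset)` would give). The function name and the toy numbers are assumptions for illustration.

```python
import math

def per_cell_average(total, n_cells, batch_size=128):
    """Return (approximate, exact) per-cell averages of a summed loss."""
    n_batches = math.ceil(n_cells / batch_size)   # what len(data) returns for a DataLoader
    approx = total / (n_batches * batch_size)     # normalization used in the training loop
    exact = total / n_cells                       # normalization by the true number of cells
    return approx, exact

# Toy example: 300 cells give 3 batches, so the loop divides by 384 instead of 300
approx, exact = per_cell_average(total=600.0, n_cells=300)
print(approx, exact)
```

The gap shrinks as the dataset grows relative to the batch size, so the logged curves remain comparable between runs on the same data, but they are not exactly per-cell losses.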

In [14]:
# DataLoader
PBMC_controlX = torch.utils.data.DataLoader(PBMC_control_filtered.X.toarray(), batch_size=128) # control training data in batches of 128 cells
PBMC_stimulatedX = torch.utils.data.DataLoader(PBMC_stimulated_filtered.X.toarray(), batch_size=128) # gene-filtered stimulated data, so the input size matches the mask

In [15]:
import torch

# NOTE: In Runtime, Change type (hardware accelerator) -> GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [98]:
# train the model for control, NO SWA
vega_ctr = VEGA(latent_dims= filtered_mask_ctr.shape[1], input_dims = filtered_mask_ctr.shape[0], mask = filtered_mask_ctr.T, z_dropout = 0.1, dropout = 0.2).to(device) # input_dim should be the
# model training
vega_ctr, vega_losses_ctr, vega_klds_ctr, vega_mses_ctr = trainVEGA(vega_ctr, PBMC_controlX,epochs = 40, beta=0.00001, learning_rate=0.01) #takes about 2 mins on GPU # change beta!!!

# the dropout rate had to be changed as well, otherwise training is too unstable


epoch:  0  loss:  168.66997
epoch:  1  loss:  83.2338
epoch:  2  loss:  77.16853
epoch:  3  loss:  74.27911
epoch:  4  loss:  71.30065
epoch:  5  loss:  68.666885
epoch:  6  loss:  69.787636
epoch:  7  loss:  69.30956
epoch:  8  loss:  66.852806
epoch:  9  loss:  66.3687
epoch:  10  loss:  65.81338
epoch:  11  loss:  65.20503
epoch:  12  loss:  65.14792
epoch:  13  loss:  64.781265
epoch:  14  loss:  64.1656
epoch:  15  loss:  64.299095
epoch:  16  loss:  64.09797
epoch:  17  loss:  63.42421
epoch:  18  loss:  63.734146
epoch:  19  loss:  63.314228
epoch:  20  loss:  63.52697
epoch:  21  loss:  63.820858
epoch:  22  loss:  64.22026
epoch:  23  loss:  63.18825
epoch:  24  loss:  63.215977
epoch:  25  loss:  63.324707
epoch:  26  loss:  62.851784
epoch:  27  loss:  62.930813
epoch:  28  loss:  63.10973
epoch:  29  loss:  63.174633
epoch:  30  loss:  63.82219
epoch:  31  loss:  62.876274
epoch:  32  loss:  62.90719
epoch:  33  loss:  62.73673
epoch:  34  loss:  62.874084
epoch:  35  loss:

In [128]:
# second attempt, implemented by hand: SWA with torch.optim.swa_utils
# loader: PBMC_controlX, optimizer: Adam, model: vega_ctr, loss: MSE + beta * KL (as in trainVEGA)
# https://docs.pytorch.org/docs/stable/optim.html#weight-averaging-swa-and-ema

from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

#training loop
def trainVEGA_SWA(vae, data, epochs=40, beta = 0.00001, learning_rate = 0.01):
    swa_model = torch.optim.swa_utils.AveragedModel(vae)
    opt = torch.optim.Adam(vae.parameters(), lr = learning_rate, weight_decay = 5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt,
                                     T_max=300)
    vae.train() #train mode
    losses = []
    klds = []
    mses = []
    swa_start = 30
    swa_scheduler = SWALR(opt, swa_lr=0.02) # SWA learning rate: 2x the default learning_rate of 0.01

    for epoch in range(epochs):
        loss_e = 0
        kld_e = 0
        mse_e = 0

        for x in data:
            x = x.to(device)
            opt.zero_grad()
            x_hat = vae(x)
            mse = ((x - x_hat)**2).sum()
            kld = beta* vae.encoder.kl
            loss = mse +  kld # loss calculation
            loss.backward()
            opt.step()
            loss_e += loss.to('cpu').detach().numpy()
            kld_e += kld.to('cpu').detach().numpy()
            mse_e += mse.to('cpu').detach().numpy()
        if epoch > swa_start:
          swa_model.update_parameters(vae)
          swa_scheduler.step()
        else:
          scheduler.step()

        losses.append(loss_e/(len(data)*128))
        klds.append(kld_e/(len(data)*128))
        mses.append(mse_e/(len(data)*128))

        print("epoch: ", epoch, " loss: ", loss_e/(len(data)*128))
    torch.optim.swa_utils.update_bn(data, swa_model)
    return swa_model, losses, klds, mses
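
One detail of the schedule is worth checking: `CosineAnnealingLR` is created with `T_max=300`, but it is only stepped for roughly the first 30 epochs before the SWA phase takes over. The sketch below evaluates the closed-form cosine annealing formula from the PyTorch documentation in plain Python to show how far the learning rate has decayed by then. The helper `cosine_annealing_lr` is an assumption written for this illustration, and setting `T_max` to the length of the cosine phase is a possible alternative, not something the notebook does.

```python
import math

def cosine_annealing_lr(base_lr, step, t_max, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at step 0, eta_min at step t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))

# As configured in trainVEGA_SWA: T_max = 300, evaluated after 30 epochs
lr_long = cosine_annealing_lr(0.01, step=30, t_max=300)
# Alternative: let the cosine phase finish exactly when the SWA phase starts
lr_short = cosine_annealing_lr(0.01, step=30, t_max=30)

print(f"T_max=300: lr after 30 epochs = {lr_long:.6f}")
print(f"T_max=30:  lr after 30 epochs = {lr_short:.6f}")
```

With `T_max=300` the learning rate is still close to its initial value of 0.01 when the SWA phase begins, so the pre-SWA phase behaves almost like training at a constant learning rate.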


In [133]:
# train the model for control
vega_ctr = VEGA(latent_dims= filtered_mask_ctr.shape[1], input_dims = filtered_mask_ctr.shape[0], mask = filtered_mask_ctr.T, z_dropout = 0.1, dropout = 0.2).to(device) # input_dim should be the
# model training

vega_ctr_swa, vega_losses_ctr, vega_klds_ctr, vega_mses_ctr = trainVEGA_SWA(vega_ctr, PBMC_controlX,epochs = 40, beta=0.00001, learning_rate=0.01) #takes about 2 mins on GPU # change beta!!!

# the dropout rate had to be changed as well, otherwise training is too unstable


epoch:  0  loss:  193.76811
epoch:  1  loss:  85.20965
epoch:  2  loss:  76.061386
epoch:  3  loss:  69.54501
epoch:  4  loss:  70.512505
epoch:  5  loss:  71.188576
epoch:  6  loss:  69.70807
epoch:  7  loss:  66.24075
epoch:  8  loss:  63.323578
epoch:  9  loss:  62.016445
epoch:  10  loss:  61.64245
epoch:  11  loss:  61.08014
epoch:  12  loss:  60.6586
epoch:  13  loss:  60.24923
epoch:  14  loss:  60.06552
epoch:  15  loss:  59.793526
epoch:  16  loss:  59.78979
epoch:  17  loss:  59.537487
epoch:  18  loss:  59.71161
epoch:  19  loss:  59.543156
epoch:  20  loss:  59.57012
epoch:  21  loss:  59.840355
epoch:  22  loss:  59.16455
epoch:  23  loss:  59.02792
epoch:  24  loss:  59.926304
epoch:  25  loss:  59.017666
epoch:  26  loss:  58.84915
epoch:  27  loss:  58.797375
epoch:  28  loss:  59.551926
epoch:  29  loss:  59.435497
epoch:  30  loss:  58.87618
epoch:  31  loss:  58.687042
epoch:  32  loss:  58.461895
epoch:  33  loss:  93.326
epoch:  34  loss:  60.10704
epoch:  35  loss

### Evaluate SWA model
- Compare SWA and non-SWA approaches by evaluating the model on test data
- Result: In summary, applying SWA helps reduce the test MSE of VEGA. However, since VEGA is probabilistic, a more systematic approach is to run the training multiple times, take the best model for each case (with and without SWA), and then compare how they perform on test data.

In [22]:
# download the data
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1zHJKoU8QcQB4cLR-oICO2YY4Nu-QaZHG' -O PBMC_test.h5ad
# load data as anndata object
PBMC_test = sc.read_h5ad("PBMC_test.h5ad")

--2025-05-23 13:46:49--  https://docs.google.com/uc?export=download&id=1zHJKoU8QcQB4cLR-oICO2YY4Nu-QaZHG
Resolving docs.google.com (docs.google.com)... 142.250.141.102, 142.250.141.100, 142.250.141.139, ...
Connecting to docs.google.com (docs.google.com)|142.250.141.102|:443... connected.
HTTP request sent, awaiting response... 303 See Other
Location: https://drive.usercontent.google.com/download?id=1zHJKoU8QcQB4cLR-oICO2YY4Nu-QaZHG&export=download [following]
--2025-05-23 13:46:49--  https://drive.usercontent.google.com/download?id=1zHJKoU8QcQB4cLR-oICO2YY4Nu-QaZHG&export=download
Resolving drive.usercontent.google.com (drive.usercontent.google.com)... 142.250.141.132, 2607:f8b0:4023:c0b::84
Connecting to drive.usercontent.google.com (drive.usercontent.google.com)|142.250.141.132|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 45277554 (43M) [application/octet-stream]
Saving to: ‘PBMC_test.h5ad’


2025-05-23 13:46:55 (33.6 MB/s) - ‘PBMC_test.h5ad’ saved [4527

In [23]:
# split data into control and stimulated
# Subset the control cells
PBMC_test_ctr = PBMC_test[PBMC_test.obs["condition"] == "control"].copy()

# Subset the stimulated cells
PBMC_test_sti = PBMC_test[PBMC_test.obs["condition"] == "stimulated"].copy()


In [24]:
PBMC_test_ctr_filtered = PBMC_test_ctr[:, filtered_mask_ctr_df.index].copy()
PBMC_test_sti_filtered = PBMC_test_sti[:, filtered_mask_ctr_df.index].copy()

In [25]:
# test control loader
PBMC_control_testX = torch.utils.data.DataLoader(PBMC_test_ctr_filtered.X.toarray(), batch_size=128) # control test data in batches of 128 cells
PBMC_stimulated_testX = torch.utils.data.DataLoader(PBMC_test_sti_filtered.X.toarray(), batch_size=128) # stimulated test data in batches of 128 cells

In [134]:
# model evaluation

import torch.nn.functional as F

torch.optim.swa_utils.update_bn(PBMC_controlX, vega_ctr_swa)
vega_ctr_swa.eval()

total_mse = 0

with torch.no_grad():
    for x in PBMC_control_testX:
        x = x.to(device)
        x_hat = vega_ctr_swa(x)
        mse = F.mse_loss(x_hat, x, reduction='sum')  # or 'mean'
        total_mse += mse.item()

avg_mse = total_mse / len(PBMC_control_testX.dataset)
print(f"Test MSE: {avg_mse:.4f}")


Test MSE: 61.2531


In [138]:
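# Compare to the un-averaged model
# NOTE: as the cells are ordered here, vega_ctr is the network that was passed to trainVEGA_SWA,
# so this evaluates the final (non-averaged) weights of that run, not a model trained with plain trainVEGA
# model evaluation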

import torch.nn.functional as F

vega_ctr.eval()

total_mse = 0

with torch.no_grad():
    for x in PBMC_control_testX:
        x = x.to(device)
        x_hat = vega_ctr(x)
        mse = F.mse_loss(x_hat, x, reduction='sum')  # or 'mean'
        total_mse += mse.item()

avg_mse = total_mse / len(PBMC_control_testX.dataset)
print(f"Test MSE: {avg_mse:.4f}")



Test MSE: 70.3493


In summary, applying SWA helps reduce MSE in VEGA: in this run the averaged model reaches a test MSE of 61.25, compared with 70.35 for the un-averaged model. However, since VEGA is probabilistic, a more systematic approach is to run the training multiple times, take the best model for each case (with and without SWA), and then compare how they perform on test data.
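
The repeated-run comparison suggested above could be summarized along these lines. This is only a sketch: the helper `summarize_runs` is an assumption, and the numbers are placeholders chosen for easy arithmetic, not results from this notebook. In practice each list would hold the test MSE of one training run per random seed.

```python
from statistics import mean, stdev

def summarize_runs(test_mses):
    """Best value, mean, and sample standard deviation over repeated runs."""
    return {"best": min(test_mses), "mean": mean(test_mses), "std": stdev(test_mses)}

# Placeholder values for illustration only; these are NOT measurements
placeholder_plain = [3.0, 2.0, 4.0]
placeholder_swa = [2.5, 1.5, 2.0]

summary_plain = summarize_runs(placeholder_plain)
summary_swa = summarize_runs(placeholder_swa)

print("without SWA:", summary_plain)
print("with SWA:   ", summary_swa)
```

Reporting the spread alongside the best run helps judge whether a difference between the two settings is larger than the run-to-run variation.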