# torchsurf ukko AML model

Notebook to develop the torchsurf-ukko model with AML data

Kernels to use:

- Carbon X1:  pytorch
- HUS Dell: 
- CSC: Python 3 (ipykernel)
- ecare4meb2 ML: 

In [None]:
import torch
import torch.nn as nn
import math
import ukko 
import importlib
# For preprocessing
#print("Loading sklearn")
#from sklearn.preprocessing import StandardScaler
#from sklearn_pandas import DataFrameMapper 
import pandas as pd
import numpy as np

from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

# Our package
from torchsurv.loss.cox import neg_partial_log_likelihood
from torchsurv.loss.weibull import neg_log_likelihood, log_hazard, survival_function
from torchsurv.metrics.brier_score import BrierScore
from torchsurv.metrics.cindex import ConcordanceIndex
from torchsurv.metrics.auc import Auc
#from torchsurv.stats.kaplan_meier import KaplanMeierEstimator

print("Libraries loaded")

For interactive plots execute:

In the classic Jupyter Notebook:

In [None]:
# Classic Jupyter Notebook: switch matplotlib to the interactive "notebook" backend.
%matplotlib notebook
import matplotlib.pyplot as plt

In VS Code and JupyterLab:

- you need `ipykernel` and `ipympl` (which provides the `widget` backend) installed, e.g. from conda-forge.

In [None]:
# The interactive "widget" backend is provided by the ipympl package; install it
# (together with ipywidgets) if missing, e.g.:
#   conda install -c conda-forge ipympl ipywidgets

# Import and configure interactive plotting
import ipywidgets as widgets
from IPython.display import display
import matplotlib.pyplot as plt

# Enable interactive mode
%matplotlib widget

# Test widget functionality
def test_interactive():
    """Display a sample FloatSlider to check that ipywidgets render in this frontend.

    Returns:
        str: the fixed message "Widget test complete".
    """
    slider = widgets.FloatSlider(
        value=0.4,
        min=0.0,
        max=1.0,
        step=0.01,
        description='Test:',
        continuous_update=False
    )
    display(slider)
    return "Widget test complete"

test_interactive()

In [None]:
# Re-select the interactive "widget" backend (e.g. after switching to inline).
%matplotlib widget
import matplotlib.pyplot as plt

In [None]:
# Static (non-interactive) inline figures.
%matplotlib inline

In [None]:
# List the matplotlib backends available in this kernel.
%matplotlib --list

## Load data

In [None]:
# Load the tidy synthetic dataset
print("Loading tidy data")
df_xy = pd.read_csv("data/df_xy_synth_v1.csv")

# Impute every NaN with -1, then split off 200 test rows and 200 validation
# rows (both sampled with random_state=42); the remaining rows are training data.
df_train = df_xy.fillna(-1)
df_test = df_train.sample(n=200, random_state=42)
df_train = df_train.drop(index=df_test.index)
df_val = df_train.sample(n=200, random_state=42)
df_train = df_train.drop(index=df_val.index)

print("\n".join([
    f"Train: {df_train.shape}",
    f"Val  : {df_val.shape}",
    f"Test : {df_test.shape}",
]))


In [None]:
# Show the first five training rows.
df_train.head(5)

## Set up for torchsurv

In [None]:
# Detect an available accelerator (CUDA GPU or Apple-silicon MPS).
# torch.backends.mps does not exist on older torch builds, hence the getattr guard.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    print("CUDA-enabled GPU is available.")
    BATCH_SIZE = 128  # batch size for training
elif _mps_backend is not None and _mps_backend.is_available():
    print("Apple MPS accelerator is available.")
    BATCH_SIZE = 128  # batch size for training
else:
    print("No GPU accelerator found, using CPU.")
    # The CPU batch size used to be 32; it is currently kept at 128 on every device.
    BATCH_SIZE = 128  # batch size for training

EPOCHS = 30  # training epochs for the Cox baseline
LEARNING_RATE = 1e-2  # Adam learning rate for the Cox baseline

In [None]:
from torch.utils.data import Dataset

class torchsurv_dataset(Dataset):
    """Custom dataset serving a tidy DataFrame (one row per sample) to torchsurv.

    Each item is ``(x, (event, time))``:

    - ``x``: float32 tensor of every column except ``person_id``,
      ``OSS_status`` and ``OSS_days``.
    - ``event``: ``OSS_status`` as a bool tensor.
    - ``time``: ``OSS_days`` as a float tensor.
    """

    def __init__(self, df: pd.DataFrame):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        sample = self.df.iloc[idx]
        # Targets
        event = torch.tensor(sample["OSS_status"]).bool()
        time = torch.tensor(sample["OSS_days"]).float()
        # Predictors. Cast explicitly: if any column of the row is non-numeric
        # (e.g. a string person_id), iloc returns an object-dtype Series whose
        # .values torch.tensor cannot convert.
        features = sample.drop(["person_id", "OSS_status", "OSS_days"]).to_numpy(dtype="float32")
        x = torch.tensor(features)
        return x, (event, time)

import matplotlib.pyplot as plt
import pandas as pd

def plot_losses(train_losses, val_losses, title: str = "Cox") -> None:

    train_losses = torch.stack(train_losses) / train_losses[0]
    val_losses = torch.stack(val_losses) / val_losses[0]

    plt.plot(train_losses, label="training")
    plt.plot(val_losses, label="validation")
    plt.legend()
    plt.xlabel("Epochs")
    plt.ylabel("Normalized loss")
    plt.title(title)
    plt.yscale("log")
    plt.show()

## Dataloaders

In [None]:
# Dataloader
# Training batches are shuffled; validation and test are each served as one
# batch holding the whole split (batch_size = number of rows).
dataloader_train = DataLoader(
    torchsurv_dataset(df_train), batch_size=BATCH_SIZE, shuffle=True
)
dataloader_val = DataLoader(
    torchsurv_dataset(df_val), batch_size=len(df_val), shuffle=False
)
dataloader_test = DataLoader(
    torchsurv_dataset(df_test), batch_size=len(df_test), shuffle=False
)

In [None]:
# Sanity check
# Pull one training batch; its second dimension (num_features) sizes the
# input layers of the Cox model defined in Section 1.
x, (event, time) = next(iter(dataloader_train))
num_features = x.size(1)

print(f"x (shape)    = {x.shape}")
print(f"num_features = {num_features}")
print(f"event        = {event.shape}")
print(f"time         = {time.shape}")

## Artificial testing dataset

In [None]:
importlib.reload(ukko.data)

# TODO: move this class to ukko.data later.
class SineWaveDatasetSurvival(ukko.data.SineWaveDataset):
    """Sine-wave dataset with synthetic survival targets.

    Wraps ``ukko.data.SineWaveDataset`` (one feature, 10 input steps) and
    replaces its own target with ``(event, time)``. This lets survival models
    be smoke-tested on data with a known signal.
    """

    def __init__(self, n_samples, weibull_shape=1, weibull_scale=1, seed=42):
        """
        Build the dataset.

        Each sample's survival time is currently deterministic:
        ``10 * A``, where ``A`` is the sample's first-feature amplitude
        (``self.f1A`` from the parent class). Event indicators are Bernoulli
        draws with p = 0.7, so roughly 30% of samples are censored.

        Args:
            n_samples: Number of samples in the dataset.
            weibull_shape: Shape parameter (k) of an auxiliary Weibull draw.
            weibull_scale: Scale parameter (lambda) of that Weibull draw.
            seed: Seed for the parent dataset and for NumPy's global RNG.
        """
        # Parent dataset with fixed settings.
        super().__init__(
            n_samples=n_samples,
            n_features=1,
            sequence_length=10,
            prediction_length=1,
            base_freq=0.1,
            noise_level=0.0,
            seed=seed,
        )

        # NOTE: reseeds NumPy's *global* random state.
        np.random.seed(seed)

        # Auxiliary 2-parameter Weibull times, X = lambda * (-ln U)^(1/k), U ~ Uniform(0, 1).
        # They are not used for survival_times at the moment. The draw must stay,
        # though: it advances the global RNG and therefore fixes the event draws
        # below. Kept on the instance for experiments (e.g. weibull_times * A).
        self.weibull_times = np.random.weibull(weibull_shape, n_samples) * weibull_scale

        # Deterministic survival times proportional to the first feature's amplitude.
        amplitudes = np.array(self.f1A)
        self.survival_times = 10 * amplitudes

        # Event indicators: 1 = event observed (p = 0.7), 0 = censored.
        self.events = np.random.binomial(1, 0.7, n_samples)

    def __getitem__(self, idx):
        """Return ``(x, (event, time))``; the parent's own target is discarded."""
        x, _ = super().__getitem__(idx)
        event = torch.tensor(self.events[idx]).bool()
        time = torch.tensor(self.survival_times[idx]).float()
        return x, (event, time)

# Example usage (in a notebook cell __name__ == "__main__", so this runs with the cell).
if __name__ == "__main__":
    # Create dataset
    dataset = SineWaveDatasetSurvival(
        n_samples=500,
        weibull_shape=5.0,
        weibull_scale=10.0
    )

    # First sample. The first target element is the event indicator
    # (True = event observed), not a censoring flag.
    x, (event, time) = dataset[0]

    fig = plt.figure(figsize=(15, 4))

    # Input features of the first sample
    plt.subplot(121)
    for f in range(dataset.n_features):
        plt.plot(x[f], label=f'Feature {f}')
    plt.title('Features')
    plt.legend()

    # Distribution of survival times over the whole dataset
    plt.subplot(122)
    plt.hist(dataset.survival_times, bins=20)
    plt.title('Survival Time Distribution')
    plt.xlabel('Time')
    plt.ylabel('Count')

    plt.tight_layout()
    plt.show()

In [None]:
# Create artificial datasets for testing (different seeds for train and validation)
train_dataset = SineWaveDatasetSurvival(n_samples=500, seed=42)
val_dataset = SineWaveDatasetSurvival(n_samples=500, seed=43)

# Create dataloaders
# NOTE(review): train_loader/val_loader are only inspected in this cell; the
# training cells iterate dataloader_train/dataloader_val instead.
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Verify the data
print("Training samples:", len(train_dataset))
print("Validation samples:", len(val_dataset))

# Check first batch
x, (event, time) = next(iter(train_loader))
print("\nFirst batch shapes:")
print(f"Features shape: {x.shape}")
print(f"Events shape: {event.shape}")
print(f"Times shape: {time.shape}")

In [None]:
# Inspect groundtruth[0, 0, :] of the training dataset.
# NOTE(review): groundtruth comes from ukko.data.SineWaveDataset; its axis meaning is not visible here.
train_dataset.groundtruth[0,0,:]

## Section 1: Classical log-hazard (Cox) model

### 1.1 Define model

In [None]:
# Feed-forward Cox baseline: one scalar output per sample, trained as the log hazard.
# Dropout() uses PyTorch's default drop probability (p=0.5).
cox_model = torch.nn.Sequential(
    torch.nn.BatchNorm1d(num_features),  # Batch normalization
    torch.nn.Linear(num_features, 32),
    torch.nn.ReLU(),
    torch.nn.Dropout(),
    torch.nn.Linear(32, 64),
    torch.nn.ReLU(),
    torch.nn.Dropout(),
    torch.nn.Linear(64, 1),  # Estimating log hazards for Cox models
)
cox_model

### 1.2 Train model

In [None]:
import warnings

# Silence all warnings globally; the Cox training cell resets the filter to "default" at its end.
warnings.filterwarnings("ignore")

In [None]:
torch.manual_seed(42)

# Init optimizer for Cox
optimizer = torch.optim.Adam(cox_model.parameters(), lr=LEARNING_RATE)

# Per-epoch mean training loss and full-batch validation loss
train_losses = []
val_losses = []

# Print about 10 progress lines; max() avoids a modulo-by-zero when EPOCHS < 10.
log_every = max(1, EPOCHS // 10)

warnings.filterwarnings("ignore")

# training loop
for epoch in range(EPOCHS):
    # Dropout/BatchNorm in training mode for the optimisation steps (the
    # validation step below switches to eval mode, so re-enable every epoch).
    cox_model.train()
    epoch_loss = torch.tensor(0.0)
    for i, batch in enumerate(dataloader_train):
        x, (event, time) = batch
        optimizer.zero_grad()
        log_hz = cox_model(x)  # shape = (batch_size, 1)
        loss = neg_partial_log_likelihood(log_hz, event, time, reduction="mean")
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach()

    # Average over batches *before* recording/printing, so both show the mean loss.
    epoch_loss /= i + 1
    train_losses.append(epoch_loss)

    # Validation in eval mode: dropout off, BatchNorm uses its running statistics.
    cox_model.eval()
    with torch.no_grad():
        x, (event, time) = next(iter(dataloader_val))
        epoch_val_loss = neg_partial_log_likelihood(cox_model(x), event, time, reduction="mean")
        val_losses.append(epoch_val_loss)

    if epoch % log_every == 0:
        print(f"Epoch: {epoch:03}, Training loss: {epoch_loss:0.2f}. Validation loss: {epoch_val_loss:0.2f}")

warnings.filterwarnings("default")

In [None]:
# Plot the Cox training/validation curves, each normalized by its first epoch.
plot_losses(train_losses, val_losses, "Cox")

## Section 2: ukko

### 2.1 Prepare data

In [None]:
# Reload ukko and ukko.utils so package edits are picked up without restarting the kernel.
importlib.reload(ukko)
importlib.reload(ukko.utils)


In [None]:
## Test:

# Convert AML data to multiindex df.
# Columns 3+ are passed as features; the first three columns are kept as ID/targets.
# (.fillna(-1) is a no-op here: df_train was already imputed with -1 when loaded.)
df_x, data_3d = ukko.utils.convert_to_3d_df(df_train.iloc[:,3:].fillna(-1))
df_y = df_train.iloc[:,:3]
#display(df_x)
#display(df_y)

In [None]:
# Shape of the 3-D feature array returned by convert_to_3d_df.
data_3d.shape

In [None]:
# Check that selecting rows by an index list yields a tensor whose first dimension is 3.
idx = [0, 1, 3]
torch.tensor(data_3d[idx,:,:]).shape

In [None]:
class ukkosurv_dataset(Dataset):
    """Custom dataset serving 3-D feature data plus survival targets to ukko/torchsurv.

    The first three columns of ``df`` are kept as the ID/target frame (it must
    contain ``OSS_status`` and ``OSS_days``). The remaining columns are turned
    into a 3-D array by ``ukko.utils.convert_to_3d_df``. Items are
    ``(x, (event, time))``.
    """

    def __init__(self, df: pd.DataFrame):
        # convert_to_3d_df also returns a MultiIndex frame, which is not needed here.
        _, data_3d = ukko.utils.convert_to_3d_df(df.iloc[:, 3:].fillna(-1))
        # Targets must come from *this* df. Reading the global df_train here would
        # pair validation/test features with training labels.
        df_y = df.iloc[:, :3]

        # Row i of data_3d is served with target row i, so the counts must agree.
        # NOTE(review): this assumes convert_to_3d_df preserves row order - confirm.
        if len(df_y) != len(data_3d):
            raise ValueError(
                f"Feature rows ({len(data_3d)}) and target rows ({len(df_y)}) differ"
            )

        self.df_y = df_y        # Dataframe with survival data, e.g. OSS_status, OSS_days
        self.data_3d = data_3d  # 3-D feature array, indexed as data_3d[idx, :, :]

    # Getting data size/length
    def __len__(self):
        return len(self.data_3d)

    # Getting the data samples
    def __getitem__(self, idx):
        y = self.df_y.iloc[idx, :]
        # Targets
        event = torch.tensor(y["OSS_status"]).bool()
        time = torch.tensor(y["OSS_days"]).float()
        # Predictors
        x = torch.tensor(self.data_3d[idx, :, :]).float()
        return x, (event, time)

In [None]:
# Dataloader
# Loaders over the 3-D ukko data. BATCH_SIZE is overridden to 512 here;
# validation and test are each served as a single full batch.
BATCH_SIZE = 512
dataloader_train = DataLoader(
    ukkosurv_dataset(df_train), batch_size=BATCH_SIZE, shuffle=True
)
dataloader_val = DataLoader(
    ukkosurv_dataset(df_val), batch_size=len(df_val), shuffle=False
)
dataloader_test = DataLoader(
    ukkosurv_dataset(df_test), batch_size=len(df_test), shuffle=False
)

In [None]:
# Show the DataLoader documentation for the training loader.
help(dataloader_train)

In [None]:
# Sanity check
# Dimensions 1 and 2 of a batch give num_features and num_timepoints,
# which configure the ukko model below.
x, (event, time) = next(iter(dataloader_train))
num_features, num_timepoints = x.size(1), x.size(2)

print(f"x (shape)      = {x.shape}")
print(f"num_features   = {num_features}")
print(f"num_timepoints = {num_timepoints}")
print(f"event          = {event.shape}")
print(f"time           = {time.shape}")

### 2.1 Artificial dataset for testing

In [None]:
# Create artificial datasets for testing
# (this replaces dataloader_train/dataloader_val built from the AML data above)
train_dataset = SineWaveDatasetSurvival(n_samples=500, seed=42)
val_dataset = SineWaveDatasetSurvival(n_samples=500, seed=43)

# Assign model parameters accordingly:
num_features   = train_dataset.n_features
num_timepoints = train_dataset.sequence_length

# Create dataloaders (batch_size equals n_samples, so each epoch is a single batch)
batch_size = 500
dataloader_train = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dataloader_val = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Verify the data
print("Training samples:", len(train_dataset))
print("Validation samples:", len(val_dataset))

# Check first batch
x, (event, time) = next(iter(dataloader_train))
print("\nFirst batch shapes:")
print(f"Features shape: {x.shape}")
print(f"Events shape: {event.shape}")
print(f"Times shape: {time.shape}")

### 2.2 Define model

In [None]:
importlib.reload(ukko.core)

# Choose the compute device here, because model.to(device) below needs it.
# It used to be defined only in the training cell, which raised NameError
# when the notebook was run top to bottom.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize model
# DualAttentionRegressor1(n_features, time_steps, d_model=128, n_heads=8, dropout=0.1, n_modules=1, ...)
# (this call also passes n_kv_heads; see ukko.core for the full signature)
model = ukko.core.DualAttentionRegressor1(
    n_features=num_features,
    time_steps=num_timepoints,
    d_model=8,
    n_heads=4,
    dropout=0.2,
    n_modules=1,
    n_kv_heads=4,
)
model.to(device)

# Confirm the parameters moved: check the input_projection weight of the first module.
if isinstance(model.modules_list[0], ukko.core.DualAttentionModule):
    print(f"input_projection weights device AFTER model.to(device): {model.modules_list[0].input_projection.weight.device}")




In [None]:
# Sanity check:

def sanity_check(model, dataloader):
    """
    Run one forward pass on the first batch and report shapes and value ranges.

    The model must return ``(predictions, feature_attention, time_attention)``.
    The input batch is moved to the device of the model's first parameter (if
    it has any), and the pass runs under ``torch.no_grad()``.

    Args:
        model: torch module to check.
        dataloader: iterable yielding ``(x, (event, time))`` batches.

    Returns:
        bool: True if the forward pass and reporting succeeded. False otherwise;
        the error is printed rather than raised, so the check never aborts the notebook.
    """
    batch, (event, time) = next(iter(dataloader))

    print("Input shapes:")
    print(f"Batch: {batch.shape}")
    print(f"Event: {event.shape}")
    print(f"Time: {time.shape}")

    # Follow the model onto its device (it may have been moved to CUDA);
    # parameterless models are left on the batch's device.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        batch = batch.to(first_param.device)

    try:
        with torch.no_grad():
            predictions, feat_attn, time_attn = model(batch)
        print("\nOutput shapes:")
        print(f"Predictions: {predictions.shape}")
        print(f"Feature attention: {feat_attn.shape}")
        print(f"Time attention: {time_attn.shape}")

        print("\nValue ranges:")
        print(f"Predictions min/max: {predictions.min():.3f}/{predictions.max():.3f}")
        print(f"Feature attention min/max: {feat_attn.min():.3f}/{feat_attn.max():.3f}")
        print(f"Time attention min/max: {time_attn.min():.3f}/{time_attn.max():.3f}")

        return True
    except Exception as e:  # deliberate best-effort: report any failure instead of raising
        print(f"Error during forward pass: {type(e).__name__}: {e}")
        return False

# Run sanity check (success is True/False; any error is printed, not raised)
success = sanity_check(model, dataloader_train)

### 2.3 Model training

In [None]:
def analyze_model_parameters(model):
    """Print a table of every named parameter (name, shape, element count) and totals.

    Args:
        model: torch module whose ``named_parameters()`` / ``parameters()`` are inspected.
    """
    rule = "-" * 80
    rows = [(name, list(param.shape), param.numel()) for name, param in model.named_parameters()]

    print("Model Parameter Analysis:")
    print(rule)
    print(f"{'Layer':<50} {'Shape':<20} {'Parameters':<10}")
    print(rule)
    for layer_name, shape, count in rows:
        print(f"{layer_name:<50} {str(shape):<20} {count:<10,d}")
    print(rule)

    grand_total = sum(count for _, _, count in rows)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total Parameters: {grand_total:,}")
    print(f"Trainable Parameters: {trainable:,}")

# Run the analysis on the current model
analyze_model_parameters(model)

In [None]:
torch.manual_seed(42)

EPOCHS = 20
LEARNING_RATE = 1e-3

# Resolve the device and move the model before building the optimizer, so the
# optimizer is created over parameters already in their final location.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")
model = model.to(device)

# Adam optimizer over the ukko model (Cox partial-likelihood objective)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Per-epoch mean training loss and full-batch validation loss
train_losses = []
val_losses = []

# training loop
for epoch in range(EPOCHS):
    # Re-enable dropout each epoch: the validation step below puts the model in eval mode.
    model.train()
    epoch_loss = torch.tensor(0.0)
    for i, batch in enumerate(dataloader_train):
        x, (event, time) = batch
        x, event, time = x.to(device), event.to(device), time.to(device)
        optimizer.zero_grad()
        log_hz, feature_weights, time_weights = model(x)  # log hazard per sample
        loss = neg_partial_log_likelihood(log_hz, event, time, reduction="mean")
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().cpu()

    # Record mean training loss and validation loss
    epoch_loss /= i + 1
    train_losses.append(epoch_loss)
    model.eval()
    with torch.no_grad():
        x, (event, time) = next(iter(dataloader_val))
        # Targets must be on the same device as the model output.
        x, event, time = x.to(device), event.to(device), time.to(device)
        log_hz, feature_weights, time_weights = model(x)
        val_loss = neg_partial_log_likelihood(log_hz, event, time, reduction="mean")
        val_losses.append(val_loss.cpu())

    # Display progress
    print(f"Epoch: {epoch:03}, Training loss: {train_losses[-1]:0.2f}, Validation loss: {val_losses[-1]:0.2f}")


In [None]:
# Evaluate on held-out data: scoring the training loader gives an optimistic concordance.
model.eval()
with torch.no_grad():
    x, (event, time) = next(iter(dataloader_val))
    # Run on the model's device, then bring the log hazard back to the CPU,
    # where the torchsurv metric and matplotlib use it together with event/time.
    model_device = next(model.parameters()).device
    log_hz, feature_weights, time_weights = model(x.to(model_device))
    log_hz = log_hz.cpu()

# Concordance index
cox_cindex = ConcordanceIndex()
print("Cox model performance:")
print(f"Concordance-index   = {cox_cindex(log_hz, event, time)}")
print(f"Confidence interval = {cox_cindex.confidence_interval()}")

# Predicted log hazard against observed time, coloured by event status
plt.figure(figsize=(10, 5))
plt.scatter(time, log_hz.reshape(-1), c=event, cmap='coolwarm', alpha=0.5)
plt.xlabel("Survival Time")
plt.ylabel("Log hazard")



In [None]:
# model.parameters() is a generator, so displaying it shows only an object repr;
# list each parameter's name and shape instead.
[(name, tuple(param.shape)) for name, param in model.named_parameters()]

### 2.4 Weibull-ukko 

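The model therefore has two outputs per sample: the log scale and the log shape of the Weibull distribution (torchsurv reads column 0 as log scale and column 1 as log shape).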

In [None]:
# Create Weibull model

importlib.reload(ukko.core)

# Initialize model
# DualAttentionRegressor1(self, n_features, time_steps, d_model=128, n_heads=8, dropout=0.1, n_modules=1)
# (this call also passes n_outputs; see ukko.core for the full signature)
model = ukko.core.DualAttentionRegressor1(
    n_features=num_features,
    time_steps=num_timepoints,
    d_model=16,
    n_heads=4,
    dropout=0.2,
    n_modules=2,
    n_outputs=2 # two Weibull outputs; torchsurv reads column 0 as log scale, column 1 as log shape
)
model

In [None]:
# Training of Weibull model

torch.manual_seed(42)

EPOCHS = 60
LEARNING_RATE = 1e-2

# Move the model to the device first, then build the optimizer over its parameters.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Adam optimizer for the Weibull-ukko model
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Per-epoch mean training loss and full-batch validation loss
train_losses = []
val_losses = []

# training loop
for epoch in range(EPOCHS):
    # Training mode for the updates (the validation step below switches to eval mode).
    model.train()
    epoch_loss = torch.tensor(0.0)
    for i, batch in enumerate(dataloader_train):
        x, (event, time) = batch
        # Tensor.to() is not in-place: the moved tensors must be reassigned.
        x, event, time = x.to(device), event.to(device), time.to(device)
        optimizer.zero_grad()
        log_params, feature_weights, time_weights = model(x)  # presumably (batch, 2): log scale, log shape
        loss = neg_log_likelihood(log_params, event, time, reduction="mean")
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().cpu()

    # Record mean training loss and validation loss
    epoch_loss /= i + 1
    train_losses.append(epoch_loss)
    model.eval()
    with torch.no_grad():
        x, (event, time) = next(iter(dataloader_val))
        x, event, time = x.to(device), event.to(device), time.to(device)
        log_params, feature_weights, time_weights = model(x)
        val_losses.append(
            neg_log_likelihood(log_params, event, time, reduction="mean").cpu()
        )

    # Display progress
    print(f"Epoch: {epoch:03}, Training loss: {train_losses[-1]:0.2f}, Validation loss: {val_losses[-1]:0.2f}")

In [None]:
# Evaluate on held-out data: scoring the training loader gives optimistic metrics.
model.eval()
with torch.no_grad():
    x, (event, time) = next(iter(dataloader_val))
    # Run on the model's device, then bring the outputs back to the CPU where
    # event/time live and where matplotlib can plot them.
    model_device = next(model.parameters()).device
    log_params, feature_weights, time_weights = model(x.to(model_device))
    log_params = log_params.cpu()

# Additional step for Weibull:
# Compute the log hazards from weibull log parameters
log_hz = log_hazard(log_params, time)
# Compute the survival probability from weibull log parameters
surv = survival_function(log_params, time)

# Concordance index
weibull_cindex = ConcordanceIndex()
print("Weibull model performance:")
print(f"Concordance-index   = {weibull_cindex(log_hz, event, time)}")
print(f"Confidence interval = {weibull_cindex.confidence_interval()}")

# H0: cindex = 0.5, Ha: cindex >0.5
print(f"p-value             = {weibull_cindex.p_value(alternative = 'greater')}")

# torchsurv convention: log_params[:, 0] is the log scale, log_params[:, 1] the log shape.
plt.figure(figsize=(10, 5))
plt.scatter(time, log_params[:,0], c=event, cmap='coolwarm', alpha=0.5)
plt.xlabel("Survival Time")
plt.ylabel("Scale parameter (log scale)")

plt.figure(figsize=(10, 5))
plt.scatter(time, log_params[:,1], c=event, cmap='coolwarm', alpha=0.5)
plt.xlabel("Survival Time")
plt.ylabel("Shape parameter (log scale)")

plt.figure(figsize=(10, 5))
plt.scatter(log_params[:,0], log_params[:,1], c=event, cmap='coolwarm', alpha=0.5)
plt.xlabel("Scale parameter (log scale)")
plt.ylabel("Shape parameter (log scale)")



In [None]:
# Shape of the Weibull log hazard computed in the evaluation cell.
log_hz.shape

# Honey

In [None]:

# First rough set of measurements (times already in decimal hours);
# both names are overwritten by the next cell.
y = [0.0011, 0.0026, 0.0062, 0.0173]
t = [19.5, 20.16, 20.75, 22.5]



In [None]:
def time_str_to_hours(time_str):
    """Convert a time string in 'XhYY' format to decimal hours.

    Args:
        time_str (str): Time such as '2h13' or '19h31'. The minutes part may be
            omitted ('3h' -> 3.0); surrounding whitespace and an upper-case
            'H' are accepted. Hours above 24 are allowed (the data below
            contains '25h38').

    Returns:
        float: Decimal hours, e.g. '2h30' -> 2.5.

    Raises:
        ValueError: If there is no 'h' separator or a part is not numeric.
    """
    hours, sep, minutes = time_str.strip().lower().partition("h")
    if not sep:
        raise ValueError(f"Expected a time like '2h13', got {time_str!r}")
    return float(hours) + float(minutes or 0) / 60

# Measured values and the clock times ('XhYY') at which they were taken
y = np.array([0.0011, 0.0026, 0.0062, 0.0173, 0.0315, 0.34])
h = ['19h31', '20h10', '20h45', '22h25', '23h00', '25h38']

# Earlier subset, kept for reference:
#y = np.array([0.0026, 0.0062, 0.0173, 0.0315])
#h = ['20h10', '20h45', '22h25', '23h00']

# Convert time strings to decimal hours
t = [time_str_to_hours(stamp) for stamp in h]
print("Times in decimal hours:", [f"{x:.2f}" for x in t])

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

# Value whose crossing time is predicted from the fit
TARGET = 0.4

# Define exponential function
def exp_func(x, a, b):
    """Exponential model y = a * exp(b * x)."""
    return a * np.exp(b * x)

def hours_to_hm(hours):
    """Format decimal hours as 'Hh Mmin', rounded to the nearest minute."""
    total_minutes = int(round(hours * 60))
    h, m = divmod(total_minutes, 60)
    return f"{h}h {m}min"

t_arr = np.asarray(t, dtype=float)
y_arr = np.asarray(y, dtype=float)

# Initial guess from a straight-line fit of log(y) against t. With t around 20 h,
# curve_fit's default p0=(1, 1) starts near exp(20), far from the data, and may
# converge poorly or not at all.
b0, log_a0 = np.polyfit(t_arr, np.log(y_arr), 1)
popt, pcov = curve_fit(exp_func, t_arr, y_arr, p0=(np.exp(log_a0), b0))
a, b = popt

# Generate points for smooth curve
t_smooth = np.linspace(t_arr.min(), 27, 100)
y_fit = exp_func(t_smooth, a, b)

# Solve a * exp(b * t) = TARGET for t
t_at_04 = np.log(TARGET / a) / b

# The hour axis continues past 24 (e.g. '25h38'); wrap to clock time. Unlike
# subtracting 24, this also handles a crossing before midnight.
time_hm = hours_to_hm(t_at_04 % 24)
print(f"Time to reach {TARGET}: {time_hm}")

# Create plot
plt.figure(figsize=(8, 6))
plt.scatter(t_arr, y_arr, color='blue', label='Data points')
plt.plot(t_smooth, y_fit, 'r-', label=f'Fit: {a:.2e}*exp({b:.2f}x)')

plt.axvline(x=t_at_04, color='green', linestyle='--', label=f't = {time_hm}')

plt.yscale('log')

plt.xlabel('Time')
plt.ylabel('Value')
plt.title('Exponential Fit')
plt.legend()
plt.grid(True)
plt.show()

print(f"Fitted function: y = {a:.2e} * exp({b:.2f} * x)")