In [1]:
import numpy as np
import sys
import os
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import TensorDataset, DataLoader
sys.path.append(os.path.abspath('C:/Users/vpming/tuni_ml/src'))
from extract_data import build_cellwise_df
from extract_cell_timetrace import extract_cell_timetrace
import torch.nn as nn

In [3]:
data_path = 'C:/Users/vpming/tuni_ml/data'
df = build_cellwise_df(data_path)

In [4]:
x = np.stack(df['time_trace'])
print(x.shape)
y = df['dis_to_target'].values
print(y.reshape(-1,1))

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)

x_scaler = StandardScaler()
x_train_s = x_scaler.fit_transform(x_train)
x_test_s = x_scaler.transform(x_test)

y_scaler = StandardScaler()
y_train_s = y_scaler.fit_transform(y_train.reshape(-1,1)).ravel()
y_test_s = y_scaler.transform(y_test.reshape(-1,1)).ravel()

(2500, 1001)
[[3]
 [3]
 [2]
 ...
 [2]
 [2]
 [3]]


In [5]:
train_ds = TensorDataset(torch.tensor(x_train_s, dtype=torch.float32),
                         torch.tensor(y_train_s, dtype=torch.float32))

test_ds = TensorDataset(torch.tensor(x_test_s, dtype=torch.float32),
                        torch.tensor(y_test_s, dtype=torch.float32))

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=64)

In [6]:
class CNNRegressor(nn.Module):
    def __init__(self, n_timepoints):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(1, 16, 7, padding=3),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(16, 32, 7, padding=3),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Flatten()
        )
        self.fc = nn.Sequential(
            nn.Linear(32 * (n_timepoints // 4), 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
    def forward(self, x):
        x = x.unsqueeze(1)  # (batch, 1, timepoints)
        x = self.conv(x)
        return self.fc(x).squeeze(-1)

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CNNRegressor(x_train.shape[1]).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

for epoch in range(20):
    model.train()
    running_loss = 0.0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        pred = model(xb)
        loss = criterion(pred, yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1:2d}: Train loss {running_loss/len(train_loader):.4f}")

# Save the model and scalers
torch.save({
    'model_state_dict': model.state_dict(),
    'scaler_X': x_scaler,
    'scaler_y': y_scaler
}, 'C:/Users/vpming/tuni_ml\src\model/cnn_dtt.pt')


Epoch  1: Train loss 0.7582
Epoch  2: Train loss 0.5816
Epoch  3: Train loss 0.5102
Epoch  4: Train loss 0.4446
Epoch  5: Train loss 0.4342
Epoch  6: Train loss 0.3576
Epoch  7: Train loss 0.3198
Epoch  8: Train loss 0.2859
Epoch  9: Train loss 0.2427
Epoch 10: Train loss 0.1901
Epoch 11: Train loss 0.1701
Epoch 12: Train loss 0.1395
Epoch 13: Train loss 0.1055
Epoch 14: Train loss 0.0863
Epoch 15: Train loss 0.0760
Epoch 16: Train loss 0.0592
Epoch 17: Train loss 0.0477
Epoch 18: Train loss 0.0409
Epoch 19: Train loss 0.0393
Epoch 20: Train loss 0.0357


In [20]:
time_trace_cell_3 = extract_cell_timetrace('C:/Users/vpming/tuni_ml/data/stim_0.5_beta_0.04_noise_0.01_kcross_0.0050/sim_data__stimMag_0.50_beta_0.40_noise_0.010_kcross_0.0050_nSamples_1000_1.h5', 24)
print(time_trace_cell_3)

timetrace_scaled = x_scaler.transform(time_trace_cell_3.reshape(1, -1))
with torch.no_grad():
    pred_scaled = model(torch.tensor(timetrace_scaled, dtype=torch.float32)).numpy()

# Inverse transform to original units
pred_distance = y_scaler.inverse_transform(pred_scaled.reshape(-1, 1)).ravel()[0]
print(f"Predicted distance to target: {pred_distance:.3f}")

[1.00822172 1.01912433 1.03525191 ... 0.99637745 0.99906408 1.00398041]
Predicted distance to target: 2.554
