In [3]:
  # Mount Google Drive at /content/drive so the FMRIR code and the datasets stored
  # under "My Drive" are reachable from this Colab runtime.
  from google.colab import drive
  drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
!pip install wandb -qU

# Log in to your W&B account
import wandb
import random
import math

wandb.login(key="ec2cf1718868be26a8055412b556d952681ee0b6")


[34m[1mwandb[0m: Currently logged in as: [33mege-erdem3[0m ([33mege-erdem-king-s-college-london[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [1]:
!git clone https://github.com/egerdem/FM-RIR.git

Cloning into 'FM-RIR'...
remote: Enumerating objects: 36, done.
remote: Counting objects: 100% (36/36), done.
remote: Compressing objects: 100% (23/23), done.
remote: Total 36 (delta 13), reused 34 (delta 11), pack-reused 0 (from 0)
Receiving objects: 100% (36/36), 13.17 MiB | 17.26 MiB/s, done.
Resolving deltas: 100% (13/13), done.


In [5]:
# Make the FMRIR package importable from the mounted Google Drive.
# Only append paths that are not already present, so re-running this cell does not
# grow sys.path with duplicates.
# NOTE(review): '/FMRIR' is an absolute root path that does not match the git clone
# location ('FM-RIR' in the working directory) — confirm whether it is still needed.
import sys

for _extra_path in (
    '/content/drive/My Drive/FMRIR',
    '/content/drive/My Drive',
    '/FMRIR',
):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

import torch
from torchvision import transforms
import os
import json
import time
import argparse


from FMRIR.fm_utils import (
    SpectrogramSampler, GaussianConditionalProbabilityPath, LinearAlpha,
    LinearBeta, CFGTrainer, SpecUNet
)

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"device is: {device}")

# Checkpoint to resume from, or None to start a new experiment.
# Expected layout (see example): <experiment_dir>/checkpoints/ckpt_<iteration>.pt
resume_from_checkpoint = None  # Example: "/content/drive/MyDrive/FMRIR/SpecUNet_20250730-112150/checkpoints/ckpt_1000.pt"

# --- Configuration ---
config = {
    "data": {
        "data_dir": "ir_fs2000_s1024_m1331_room4.0x6.0x3.0_rt200/",
        # Half-open [start, end) source-index ranges for each split.
        "src_splits": {
            "train": [0, 820],
            "valid": [820, 922],
            "test": [922, 1024],
            "all": [0, 1024]}
    },
    "model": {
        "name": "SpecUNet",
        "channels": [32, 64, 128],
        "num_residual_layers": 2,
        "t_embed_dim": 40,
        "y_dim": 6,
        "y_embed_dim": 40
    },
    "training": {
        # NOTE(review): the recorded training output is from a 5000-iteration run — confirm the intended value.
        "num_iterations": 8,
        "batch_size": 250,
        "lr": 1e-3,
        "eta": 0.1,
        # Iterations between checkpoints written to CHECKPOINT_DIR.
        "checkpoint_interval": 1000,
    },
    # Relative to the working directory. On Colab this is local runtime disk unless it points into Drive.
    "experiments_dir": "experiments"
}

start_iteration = 0
checkpoint = None  # Populated only when resuming from a checkpoint.

if resume_from_checkpoint and not os.path.exists(resume_from_checkpoint):
    # Do not silently start over when a resume was explicitly requested.
    print(f"WARNING: checkpoint not found: {resume_from_checkpoint} -- starting a new experiment instead.")

if resume_from_checkpoint and os.path.exists(resume_from_checkpoint):
    print(f"Resuming training from checkpoint: {resume_from_checkpoint}")
    checkpoint = torch.load(resume_from_checkpoint, map_location=device)
    start_iteration = checkpoint['iteration']

    # Optionally restore the saved config for consistency (assumes the checkpoint stores it):
    # config = checkpoint['config']

    # Previously experiment_dir/experiment_name were only defined for new runs, so resuming
    # crashed with a NameError below. Recover them from the checkpoint location
    # (<experiment_dir>/checkpoints/ckpt_N.pt -> two directory levels up).
    experiment_dir = os.path.dirname(os.path.dirname(os.path.abspath(resume_from_checkpoint)))
    experiment_name = os.path.basename(experiment_dir)

    # The finalizing step uses wandb.run, so a run must be active when resuming too.
    # Continue the original run if the checkpoint recorded its id; otherwise start a new run.
    run_id = checkpoint.get('wandb_run_id')
    if run_id:
        wandb.init(project="FM-RIR", id=run_id, resume="must", config=config)
    else:
        wandb.init(project="FM-RIR", name=experiment_name, config=config)

else:
    # --- Experiment Setup ---
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    experiment_name = f"{config['model']['name']}_{timestamp}"
    experiment_dir = os.path.join(config['experiments_dir'], experiment_name)
    os.makedirs(experiment_dir, exist_ok=True)

    # Initialize a new wandb run
    wandb.init(project="FM-RIR", name=experiment_name, config=config)

MODEL_SAVE_PATH = os.path.join(experiment_dir, "model.pt")
CHECKPOINT_DIR = os.path.join(experiment_dir, "checkpoints")
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
CONFIG_SAVE_PATH = os.path.join(experiment_dir, "config.json")

with open(CONFIG_SAVE_PATH, 'w') as f:
    json.dump(config, f, indent=4)
print(f"Experiment setup. Config saved to {CONFIG_SAVE_PATH}")

device is: cuda


In [7]:
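# Alternative W&B initialisation with an explicit entity (team). Not used: the run is
# started by wandb.init(...) in the experiment-setup cell. Kept for reference only.
# run = wandb.init(
#     # Set the wandb entity where your project will be logged (generally your team name).
#     entity="ege-erdem-king-s-college-london",
#     # Set the wandb project where this run will be logged.
#     project="FM-RIR",
#     # Track hyperparameters and run metadata.
#     config=config
#     )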

In [8]:
# --- Data Loading ---
data_cfg = config['data']

# --- Instantiate samplers for each split (only ONCE; loading/processing is expensive) ---
spec_train_sampler = SpectrogramSampler(
    data_path=data_cfg['data_dir'], mode='train', src_splits=data_cfg['src_splits']
).to(device)

spec_valid_sampler = SpectrogramSampler(
    data_path=data_cfg['data_dir'], mode='valid', src_splits=data_cfg['src_splits']
).to(device)

# --- Normalization statistics ---
# Computed from the training split only, so no validation/test information leaks into training.
# .item() turns the 0-d tensors into plain Python floats for Normalize and for printing.
spec_mean = spec_train_sampler.spectrograms.mean().item()
spec_std = spec_train_sampler.spectrograms.std().item()
print(f"\nCalculated Mean: {spec_mean:.4f}, Std: {spec_std:.4f} (from training set)")

# --- Define and apply the transform to the existing samplers ---
transform = transforms.Compose([
    transforms.Normalize((spec_mean,), (spec_std,)),
])
spec_train_sampler.transform = transform
spec_valid_sampler.transform = transform

# Infer the per-sample spectrogram shape (batch dimension dropped) from a single draw.
sample_spec, _ = spec_train_sampler.sample(1)
spec_shape = list(sample_spec.shape[1:])

path = GaussianConditionalProbabilityPath(
    p_data=spec_train_sampler,
    p_simple_shape=spec_shape,
    alpha=LinearAlpha(),
    beta=LinearBeta()
).to(device)


# --- Model and Trainer Initialization ---
model_cfg = config['model']
training_cfg = config['training']

spec_unet = SpecUNet(
    channels=model_cfg['channels'],
    num_residual_layers=model_cfg['num_residual_layers'],
    t_embed_dim=model_cfg['t_embed_dim'],
    y_dim=model_cfg['y_dim'],  # conditioning vector: 6D source + microphone coordinates
    y_embed_dim=model_cfg['y_embed_dim'],
).to(device)

trainer = CFGTrainer(
    path=path,
    model=spec_unet,
    eta=training_cfg['eta'],
    y_dim=model_cfg['y_dim'],
)

# When resuming, restore the model weights and the trainer's y_null tensor from the checkpoint.
if start_iteration > 0:
    spec_unet.load_state_dict(checkpoint['model_state_dict'])
    trainer.y_null.data = checkpoint['y_null'].to(device)

# --- Training ---
# trainer.train() saves the best model (by validation loss) to MODEL_SAVE_PATH itself,
# so no separate save step is needed afterwards.
print(f"\n--- Starting Training for experiment: {experiment_name} ---")
trainer.train(
    num_iterations=training_cfg['num_iterations'],
    device=device,
    lr=training_cfg['lr'],
    batch_size=training_cfg['batch_size'],
    valid_sampler=spec_valid_sampler,
    save_path=MODEL_SAVE_PATH,
    checkpoint_path=CHECKPOINT_DIR,
    checkpoint_interval=training_cfg.get('checkpoint_interval', 1000),
    start_iteration=start_iteration,  # 0 for a new run, or the iteration loaded from the checkpoint
    config=config
)

# --- Finalizing the Run ---
# Log the best model as a wandb Artifact for easy access later. wandb.run is None when no
# run is active, and creating or logging the artifact would then fail.
if os.path.exists(MODEL_SAVE_PATH) and wandb.run is not None:
    print("Logging best model to W&B Artifacts...")
    best_model_artifact = wandb.Artifact(f"{wandb.run.name}-best-model", type="model")
    best_model_artifact.add_file(MODEL_SAVE_PATH)
    wandb.log_artifact(best_model_artifact)

wandb.finish()
print("Training complete and wandb run finished.")
print(f"Experiment directory: {experiment_dir}")

Loading pre-processed train data from /content/drive/My Drive/ir_fs2000_s1024_m1331_room4.0x6.0x3.0_rt200/processed_train.pt
Loaded 820.0 * 1331 = 1091420 spectrograms for train set.
Spectrogram tensor shape: torch.Size([1091420, 16, 16])
Coordinate tensor shape: torch.Size([1091420, 6])
Processing valid data from .npz files...


Loading valid NPZ files: 100%|██████████| 102/102 [01:20<00:00,  1.27it/s]


Saved processed valid data to /content/drive/My Drive/ir_fs2000_s1024_m1331_room4.0x6.0x3.0_rt200/processed_valid.pt
Loaded 102.0 * 1331 = 135762 spectrograms for valid set.
Spectrogram tensor shape: torch.Size([135762, 16, 16])
Coordinate tensor shape: torch.Size([135762, 6])

Calculated Mean: -68.6621, Std: 10.4724 (from training set)

--- Starting Training for experiment: SpecUNet_20250730-112150 ---
Training model with size: 4.721 MiB


Epoch: 0.0005, Iter: 1, Loss: 551.353, Val Loss: 502.575:   0%|          | 2/5000 [00:01<1:08:23,  1.22it/s]

** New best validation loss: 490.978. Saving model. **


Epoch: 0.0011, Iter: 4, Loss: 453.389, Val Loss: 467.434:   0%|          | 5/5000 [00:02<24:20,  3.42it/s]

** New best validation loss: 486.211. Saving model. **
** New best validation loss: 467.434. Saving model. **


Epoch: 0.0016, Iter: 6, Loss: 402.786, Val Loss: 461.202:   0%|          | 7/5000 [00:02<16:46,  4.96it/s]

** New best validation loss: 464.131. Saving model. **
** New best validation loss: 461.202. Saving model. **


Epoch: 0.0021, Iter: 8, Loss: 351.145, Val Loss: 435.296:   0%|          | 9/5000 [00:02<13:37,  6.11it/s]

** New best validation loss: 455.567. Saving model. **
** New best validation loss: 435.296. Saving model. **


Epoch: 0.0025, Iter: 10, Loss: 327.858, Val Loss: 420.072:   0%|          | 11/5000 [00:03<12:11,  6.82it/s]

** New best validation loss: 433.566. Saving model. **
** New best validation loss: 420.072. Saving model. **


Epoch: 0.0030, Iter: 12, Loss: 326.232, Val Loss: 390.070:   0%|          | 13/5000 [00:03<11:31,  7.21it/s]

** New best validation loss: 408.438. Saving model. **
** New best validation loss: 390.070. Saving model. **


Epoch: 0.0037, Iter: 15, Loss: 316.104, Val Loss: 373.294:   0%|          | 16/5000 [00:03<11:02,  7.52it/s]

** New best validation loss: 383.167. Saving model. **
** New best validation loss: 373.294. Saving model. **


Epoch: 0.0041, Iter: 17, Loss: 303.186, Val Loss: 355.784:   0%|          | 18/5000 [00:04<10:50,  7.66it/s]

** New best validation loss: 365.365. Saving model. **
** New best validation loss: 355.784. Saving model. **


Epoch: 0.0046, Iter: 19, Loss: 302.855, Val Loss: 333.031:   0%|          | 20/5000 [00:04<10:47,  7.69it/s]

** New best validation loss: 354.874. Saving model. **
** New best validation loss: 333.031. Saving model. **


Epoch: 0.0050, Iter: 21, Loss: 291.798, Val Loss: 329.493:   0%|          | 22/5000 [00:04<10:45,  7.72it/s]

** New best validation loss: 329.970. Saving model. **
** New best validation loss: 329.493. Saving model. **


Epoch: 0.0055, Iter: 23, Loss: 293.612, Val Loss: 323.584:   0%|          | 24/5000 [00:04<10:39,  7.78it/s]

** New best validation loss: 315.727. Saving model. **


Epoch: 0.0060, Iter: 25, Loss: 287.753, Val Loss: 301.537:   1%|          | 26/5000 [00:05<10:45,  7.71it/s]

** New best validation loss: 302.496. Saving model. **
** New best validation loss: 301.537. Saving model. **


Epoch: 0.0064, Iter: 27, Loss: 282.016, Val Loss: 291.392:   1%|          | 28/5000 [00:05<10:46,  7.70it/s]

** New best validation loss: 297.698. Saving model. **
** New best validation loss: 291.392. Saving model. **


Epoch: 0.0069, Iter: 29, Loss: 276.045, Val Loss: 289.653:   1%|          | 30/5000 [00:05<10:32,  7.86it/s]

** New best validation loss: 287.908. Saving model. **


Epoch: 0.0073, Iter: 31, Loss: 278.755, Val Loss: 282.378:   1%|          | 32/5000 [00:05<10:50,  7.64it/s]

** New best validation loss: 282.314. Saving model. **


Epoch: 0.0080, Iter: 34, Loss: 274.847, Val Loss: 277.182:   1%|          | 35/5000 [00:06<10:58,  7.53it/s]

** New best validation loss: 279.910. Saving model. **
** New best validation loss: 277.182. Saving model. **


Epoch: 0.0085, Iter: 36, Loss: 273.116, Val Loss: 273.280:   1%|          | 37/5000 [00:06<10:48,  7.66it/s]

** New best validation loss: 276.558. Saving model. **
** New best validation loss: 273.280. Saving model. **


Epoch: 0.0092, Iter: 39, Loss: 266.737, Val Loss: 270.558:   1%|          | 40/5000 [00:06<10:47,  7.66it/s]

** New best validation loss: 268.726. Saving model. **


Epoch: 0.0096, Iter: 41, Loss: 261.154, Val Loss: 266.077:   1%|          | 42/5000 [00:07<11:37,  7.11it/s]

** New best validation loss: 265.729. Saving model. **


Epoch: 0.0108, Iter: 46, Loss: 264.957, Val Loss: 265.042:   1%|          | 47/5000 [00:07<12:00,  6.87it/s]

** New best validation loss: 262.427. Saving model. **


Epoch: 0.0112, Iter: 48, Loss: 263.674, Val Loss: 262.254:   1%|          | 49/5000 [00:08<12:25,  6.64it/s]

** New best validation loss: 260.966. Saving model. **


Epoch: 0.0117, Iter: 50, Loss: 263.200, Val Loss: 260.172:   1%|          | 51/5000 [00:08<12:24,  6.65it/s]

** New best validation loss: 256.732. Saving model. **


Epoch: 0.0121, Iter: 52, Loss: 258.764, Val Loss: 261.110:   1%|          | 53/5000 [00:08<12:30,  6.59it/s]

** New best validation loss: 252.703. Saving model. **


Epoch: 0.0133, Iter: 57, Loss: 254.257, Val Loss: 249.808:   1%|          | 58/5000 [00:09<13:22,  6.16it/s]

** New best validation loss: 246.106. Saving model. **


Epoch: 0.0137, Iter: 59, Loss: 249.702, Val Loss: 249.290:   1%|          | 60/5000 [00:10<14:27,  5.70it/s]

** New best validation loss: 245.497. Saving model. **


Epoch: 0.0151, Iter: 65, Loss: 247.696, Val Loss: 244.456:   1%|▏         | 66/5000 [00:11<13:46,  5.97it/s]

** New best validation loss: 243.592. Saving model. **


Epoch: 0.0163, Iter: 70, Loss: 246.812, Val Loss: 241.479:   1%|▏         | 71/5000 [00:11<10:56,  7.51it/s]

** New best validation loss: 240.212. Saving model. **


Epoch: 0.0170, Iter: 73, Loss: 240.490, Val Loss: 243.832:   1%|▏         | 74/5000 [00:12<10:24,  7.89it/s]

** New best validation loss: 237.822. Saving model. **


Epoch: 0.0181, Iter: 78, Loss: 240.327, Val Loss: 235.524:   2%|▏         | 79/5000 [00:12<10:37,  7.71it/s]

** New best validation loss: 237.763. Saving model. **
** New best validation loss: 235.524. Saving model. **


Epoch: 0.0190, Iter: 82, Loss: 234.909, Val Loss: 231.603:   2%|▏         | 83/5000 [00:13<10:25,  7.85it/s]

** New best validation loss: 233.613. Saving model. **
** New best validation loss: 231.603. Saving model. **


Epoch: 0.0197, Iter: 85, Loss: 232.684, Val Loss: 231.492:   2%|▏         | 86/5000 [00:13<10:18,  7.95it/s]

** New best validation loss: 228.008. Saving model. **


Epoch: 0.0213, Iter: 92, Loss: 225.859, Val Loss: 229.209:   2%|▏         | 93/5000 [00:14<10:25,  7.84it/s]

** New best validation loss: 226.272. Saving model. **


Epoch: 0.0218, Iter: 94, Loss: 227.447, Val Loss: 221.489:   2%|▏         | 95/5000 [00:14<10:42,  7.64it/s]

** New best validation loss: 225.313. Saving model. **
** New best validation loss: 221.489. Saving model. **


Epoch: 0.0229, Iter: 99, Loss: 225.837, Val Loss: 222.771:   2%|▏         | 100/5000 [00:15<10:14,  7.97it/s]

** New best validation loss: 218.684. Saving model. **


Epoch: 0.0254, Iter: 110, Loss: 223.286, Val Loss: 217.559:   2%|▏         | 111/5000 [00:16<10:21,  7.87it/s]

** New best validation loss: 215.724. Saving model. **


Epoch: 0.0259, Iter: 112, Loss: 222.049, Val Loss: 218.751:   2%|▏         | 113/5000 [00:17<10:19,  7.88it/s]

** New best validation loss: 213.082. Saving model. **


Epoch: 0.0275, Iter: 119, Loss: 218.755, Val Loss: 215.480:   2%|▏         | 120/5000 [00:17<10:16,  7.92it/s]

** New best validation loss: 209.953. Saving model. **


Epoch: 0.0293, Iter: 127, Loss: 209.689, Val Loss: 212.618:   3%|▎         | 128/5000 [00:18<10:25,  7.79it/s]

** New best validation loss: 205.919. Saving model. **


Epoch: 0.0307, Iter: 133, Loss: 213.270, Val Loss: 209.072:   3%|▎         | 134/5000 [00:19<10:24,  7.79it/s]

** New best validation loss: 205.403. Saving model. **


Epoch: 0.0314, Iter: 136, Loss: 213.651, Val Loss: 204.498:   3%|▎         | 137/5000 [00:20<10:29,  7.72it/s]

** New best validation loss: 203.314. Saving model. **


Epoch: 0.0318, Iter: 138, Loss: 197.203, Val Loss: 212.936:   3%|▎         | 139/5000 [00:20<10:28,  7.73it/s]

** New best validation loss: 200.857. Saving model. **


Epoch: 0.0332, Iter: 144, Loss: 202.573, Val Loss: 201.203:   3%|▎         | 145/5000 [00:21<11:43,  6.90it/s]

** New best validation loss: 200.146. Saving model. **


Epoch: 0.0344, Iter: 149, Loss: 198.166, Val Loss: 200.090:   3%|▎         | 150/5000 [00:21<11:51,  6.82it/s]

** New best validation loss: 194.687. Saving model. **


Epoch: 0.0366, Iter: 159, Loss: 194.758, Val Loss: 195.843:   3%|▎         | 160/5000 [00:23<12:02,  6.70it/s]

** New best validation loss: 188.823. Saving model. **


Epoch: 0.0371, Iter: 161, Loss: 194.709, Val Loss: 198.111:   3%|▎         | 162/5000 [00:23<12:07,  6.65it/s]

** New best validation loss: 188.649. Saving model. **


Epoch: 0.0387, Iter: 168, Loss: 198.610, Val Loss: 199.361:   3%|▎         | 169/5000 [00:24<11:15,  7.15it/s]

** New best validation loss: 184.788. Saving model. **


Epoch: 0.0399, Iter: 173, Loss: 187.580, Val Loss: 193.302:   3%|▎         | 174/5000 [00:25<10:21,  7.77it/s]

** New best validation loss: 182.621. Saving model. **


Epoch: 0.0435, Iter: 189, Loss: 180.422, Val Loss: 178.021:   4%|▍         | 190/5000 [00:27<10:19,  7.77it/s]

** New best validation loss: 181.330. Saving model. **
** New best validation loss: 178.021. Saving model. **


Epoch: 0.0447, Iter: 194, Loss: 181.616, Val Loss: 182.601:   4%|▍         | 195/5000 [00:28<10:04,  7.95it/s]

** New best validation loss: 171.364. Saving model. **


Epoch: 0.0476, Iter: 207, Loss: 182.634, Val Loss: 177.695:   4%|▍         | 208/5000 [00:29<10:08,  7.87it/s]

** New best validation loss: 171.055. Saving model. **


Epoch: 0.0504, Iter: 219, Loss: 170.602, Val Loss: 174.138:   4%|▍         | 220/5000 [00:31<10:04,  7.91it/s]

** New best validation loss: 169.903. Saving model. **


Epoch: 0.0509, Iter: 221, Loss: 174.695, Val Loss: 170.725:   4%|▍         | 222/5000 [00:31<10:16,  7.75it/s]

** New best validation loss: 164.939. Saving model. **


Epoch: 0.0538, Iter: 234, Loss: 167.340, Val Loss: 169.729:   5%|▍         | 235/5000 [00:33<10:13,  7.76it/s]

** New best validation loss: 163.640. Saving model. **


Epoch: 0.0550, Iter: 239, Loss: 166.790, Val Loss: 166.518:   5%|▍         | 240/5000 [00:33<10:08,  7.83it/s]

** New best validation loss: 161.508. Saving model. **


Epoch: 0.0589, Iter: 256, Loss: 165.962, Val Loss: 162.960:   5%|▌         | 257/5000 [00:36<11:48,  6.69it/s]

** New best validation loss: 160.648. Saving model. **


Epoch: 0.0593, Iter: 258, Loss: 166.269, Val Loss: 170.862:   5%|▌         | 259/5000 [00:36<12:02,  6.56it/s]

** New best validation loss: 160.400. Saving model. **


Epoch: 0.0602, Iter: 262, Loss: 171.919, Val Loss: 167.456:   5%|▌         | 263/5000 [00:37<11:37,  6.80it/s]

** New best validation loss: 156.707. Saving model. **


Epoch: 0.0607, Iter: 264, Loss: 158.791, Val Loss: 166.425:   5%|▌         | 265/5000 [00:37<11:53,  6.64it/s]

** New best validation loss: 155.746. Saving model. **


Epoch: 0.0621, Iter: 270, Loss: 158.714, Val Loss: 155.163:   5%|▌         | 271/5000 [00:38<11:31,  6.83it/s]

** New best validation loss: 149.669. Saving model. **


Epoch: 0.0680, Iter: 296, Loss: 164.625, Val Loss: 153.921:   6%|▌         | 297/5000 [00:41<10:01,  7.81it/s]

** New best validation loss: 149.590. Saving model. **


Epoch: 0.0694, Iter: 302, Loss: 157.309, Val Loss: 149.170:   6%|▌         | 303/5000 [00:42<10:01,  7.80it/s]

** New best validation loss: 146.858. Saving model. **


Epoch: 0.0701, Iter: 305, Loss: 147.326, Val Loss: 152.650:   6%|▌         | 306/5000 [00:42<10:31,  7.43it/s]

** New best validation loss: 146.012. Saving model. **


Epoch: 0.0715, Iter: 311, Loss: 154.108, Val Loss: 151.397:   6%|▌         | 312/5000 [00:43<10:14,  7.63it/s]

** New best validation loss: 140.528. Saving model. **


Epoch: 0.0880, Iter: 383, Loss: 148.734, Val Loss: 146.556:   8%|▊         | 384/5000 [00:53<10:01,  7.68it/s]

** New best validation loss: 139.351. Saving model. **


Epoch: 0.0891, Iter: 388, Loss: 145.820, Val Loss: 145.896:   8%|▊         | 389/5000 [00:54<09:58,  7.70it/s]

** New best validation loss: 131.069. Saving model. **


Epoch: 0.0996, Iter: 434, Loss: 147.754, Val Loss: 136.750:   9%|▊         | 435/5000 [01:00<09:25,  8.07it/s]

** New best validation loss: 129.420. Saving model. **


Epoch: 0.1045, Iter: 455, Loss: 133.100, Val Loss: 140.709:   9%|▉         | 456/5000 [01:03<11:17,  6.70it/s]

** New best validation loss: 128.962. Saving model. **


Epoch: 0.1104, Iter: 481, Loss: 131.690, Val Loss: 138.407:  10%|▉         | 482/5000 [01:07<09:43,  7.74it/s]

** New best validation loss: 128.062. Saving model. **


Epoch: 0.1214, Iter: 529, Loss: 130.752, Val Loss: 132.415:  11%|█         | 530/5000 [01:13<09:30,  7.84it/s]

** New best validation loss: 127.786. Saving model. **


Epoch: 0.1228, Iter: 535, Loss: 138.070, Val Loss: 131.525:  11%|█         | 536/5000 [01:14<09:34,  7.76it/s]

** New best validation loss: 125.771. Saving model. **


Epoch: 0.1242, Iter: 541, Loss: 127.250, Val Loss: 136.627:  11%|█         | 542/5000 [01:14<09:32,  7.79it/s]

** New best validation loss: 125.661. Saving model. **


Epoch: 0.1329, Iter: 579, Loss: 135.914, Val Loss: 129.844:  12%|█▏        | 580/5000 [01:20<09:36,  7.67it/s]

** New best validation loss: 123.596. Saving model. **


Epoch: 0.1358, Iter: 592, Loss: 127.543, Val Loss: 134.008:  12%|█▏        | 593/5000 [01:22<09:30,  7.73it/s]

** New best validation loss: 123.000. Saving model. **


Epoch: 0.1425, Iter: 621, Loss: 125.140, Val Loss: 127.857:  12%|█▏        | 622/5000 [01:25<09:25,  7.74it/s]

** New best validation loss: 122.029. Saving model. **


Epoch: 0.1457, Iter: 635, Loss: 129.238, Val Loss: 128.487:  13%|█▎        | 636/5000 [01:27<09:18,  7.82it/s]

** New best validation loss: 121.050. Saving model. **


Epoch: 0.1562, Iter: 681, Loss: 128.806, Val Loss: 136.693:  14%|█▎        | 682/5000 [01:34<09:34,  7.52it/s]

** New best validation loss: 117.425. Saving model. **


Epoch: 0.1697, Iter: 740, Loss: 123.551, Val Loss: 125.434:  15%|█▍        | 741/5000 [01:41<09:18,  7.63it/s]

** New best validation loss: 115.276. Saving model. **


Epoch: 0.1860, Iter: 811, Loss: 120.497, Val Loss: 120.194:  16%|█▌        | 812/5000 [01:51<08:47,  7.94it/s]

** New best validation loss: 113.295. Saving model. **


Epoch: 0.2023, Iter: 882, Loss: 120.910, Val Loss: 121.878:  18%|█▊        | 883/5000 [02:01<09:42,  7.07it/s]

** New best validation loss: 112.244. Saving model. **


Epoch: 0.2062, Iter: 899, Loss: 120.042, Val Loss: 124.736:  18%|█▊        | 900/5000 [02:03<08:32,  7.99it/s]

** New best validation loss: 111.110. Saving model. **


Epoch: 0.2160, Iter: 942, Loss: 124.361, Val Loss: 122.163:  19%|█▉        | 943/5000 [02:08<08:40,  7.80it/s]

** New best validation loss: 109.875. Saving model. **


Epoch: 0.2524, Iter: 1101, Loss: 115.888, Val Loss: 115.776:  22%|██▏       | 1102/5000 [02:30<08:25,  7.71it/s]

** New best validation loss: 107.040. Saving model. **


Epoch: 0.2671, Iter: 1165, Loss: 123.763, Val Loss: 112.527:  23%|██▎       | 1166/5000 [02:38<09:15,  6.91it/s]

** New best validation loss: 103.378. Saving model. **


Epoch: 0.3175, Iter: 1385, Loss: 105.729, Val Loss: 101.212:  28%|██▊       | 1386/5000 [03:08<09:41,  6.21it/s]

** New best validation loss: 101.449. Saving model. **
** New best validation loss: 101.212. Saving model. **


Epoch: 0.3523, Iter: 1537, Loss: 109.762, Val Loss: 104.628:  31%|███       | 1538/5000 [03:28<07:21,  7.84it/s]

** New best validation loss: 99.545. Saving model. **


Epoch: 0.3603, Iter: 1572, Loss: 111.368, Val Loss: 107.962:  31%|███▏      | 1573/5000 [03:33<08:03,  7.09it/s]

** New best validation loss: 97.450. Saving model. **


Epoch: 0.3670, Iter: 1601, Loss: 109.260, Val Loss: 107.682:  32%|███▏      | 1602/5000 [03:37<07:29,  7.56it/s]

** New best validation loss: 97.132. Saving model. **


Epoch: 0.3896, Iter: 1700, Loss: 109.845, Val Loss: 105.049:  34%|███▍      | 1701/5000 [03:51<07:21,  7.48it/s]

** New best validation loss: 94.781. Saving model. **


Epoch: 0.4396, Iter: 1918, Loss: 105.189, Val Loss: 101.692:  38%|███▊      | 1919/5000 [04:19<06:48,  7.54it/s]

** New best validation loss: 94.417. Saving model. **


Epoch: 0.4854, Iter: 2118, Loss: 100.655, Val Loss: 100.535:  42%|████▏     | 2119/5000 [04:46<06:08,  7.83it/s]

** New best validation loss: 92.686. Saving model. **


Epoch: 0.5497, Iter: 2399, Loss: 97.087, Val Loss: 105.297:  48%|████▊     | 2400/5000 [05:23<06:24,  6.77it/s]

** New best validation loss: 91.545. Saving model. **


Epoch: 0.5541, Iter: 2418, Loss: 99.540, Val Loss: 96.885:  48%|████▊     | 2419/5000 [05:26<06:26,  6.68it/s]

** New best validation loss: 91.386. Saving model. **


Epoch: 0.5548, Iter: 2421, Loss: 101.913, Val Loss: 102.515:  48%|████▊     | 2422/5000 [05:26<05:52,  7.32it/s]

** New best validation loss: 90.854. Saving model. **


Epoch: 0.5752, Iter: 2510, Loss: 99.280, Val Loss: 97.597:  50%|█████     | 2511/5000 [05:38<06:03,  6.85it/s]

** New best validation loss: 89.935. Saving model. **


Epoch: 0.6137, Iter: 2678, Loss: 103.246, Val Loss: 94.171:  54%|█████▎    | 2679/5000 [05:59<04:47,  8.07it/s]

** New best validation loss: 84.848. Saving model. **


Epoch: 0.7188, Iter: 3137, Loss: 94.136, Val Loss: 96.368:  63%|██████▎   | 3138/5000 [06:59<04:22,  7.10it/s]

** New best validation loss: 83.769. Saving model. **


Epoch: 0.8505, Iter: 3712, Loss: 92.189, Val Loss: 92.516:  74%|███████▍  | 3713/5000 [08:15<02:46,  7.74it/s]

** New best validation loss: 83.706. Saving model. **


Epoch: 0.9868, Iter: 4307, Loss: 88.839, Val Loss: 100.275:  86%|████████▌ | 4308/5000 [09:35<01:32,  7.48it/s]

** New best validation loss: 83.647. Saving model. **


Epoch: 1.0367, Iter: 4525, Loss: 89.607, Val Loss: 97.390:  91%|█████████ | 4526/5000 [10:05<01:01,  7.73it/s]

** New best validation loss: 81.276. Saving model. **


Epoch: 1.0688, Iter: 4665, Loss: 100.085, Val Loss: 89.666:  93%|█████████▎| 4666/5000 [10:23<00:42,  7.81it/s]

** New best validation loss: 80.538. Saving model. **


Epoch: 1.1453, Iter: 4999, Loss: 92.499, Val Loss: 92.131: 100%|██████████| 5000/5000 [11:09<00:00,  7.47it/s]


Saving best model with val loss 80.538 to /content/drive/MyDrive/FMRIR/SpecUNet_20250730-112150/model.pt
Saving model to /content/drive/MyDrive/FMRIR/SpecUNet_20250730-112150/model.pt...
