# Debug Trainer
Mirror of the complex U-Net run from `experiment_runner.ipynb` with hooks to inspect losses.

In [None]:
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
%env PYTORCH_ENABLE_MPS_FALLBACK=1
from pathlib import Path
import torch
from torch.utils.data import DataLoader, Subset
import sys

# Resolve the repository root. The kernel may have been started either in the
# project root itself or in its `notebooks/` subdirectory; in the latter case
# step one level up.
PROJECT_ROOT = Path.cwd().resolve()
PROJECT_ROOT = PROJECT_ROOT.parent if PROJECT_ROOT.name == 'notebooks' else PROJECT_ROOT

# Make the packages under `src/` importable. The entry is appended only when it
# is not already present, so re-running this cell does not grow `sys.path`.
SRC_ROOT = PROJECT_ROOT.joinpath('src')
if sys.path.count(str(SRC_ROOT)) == 0:
    sys.path.append(str(SRC_ROOT))

from data.dataset import SingleCoilDataset
from data.masking import EquispacedMasker
from models.real_unet import RealUnet
from models.cx_unet import ComplexUnet
from training.utils import test_loop


env: PYTORCH_ENABLE_MPS_FALLBACK=1


In [2]:

# Every tunable for this debug run lives in this one mapping; the cells below
# read from it rather than hard-coding values.
CONFIG = dict(
    model='complex',  # 'complex' or 'real'
    # Data folders are resolved relative to the detected project root.
    train_folder=str(PROJECT_ROOT / 'data' / 'singlecoil_train'),
    val_folder=str(PROJECT_ROOT / 'data' / 'singlecoil_val'),
    # Forwarded to EquispacedMasker as its `accel` and `acs` arguments.
    mask=dict(accel=4, acs=24),
    # Upper bounds on how many leading dataset items are kept per split.
    train_subset=1024,
    val_subset=256,
    # DataLoader settings.
    batch_size=4,
    num_workers=2,
    # Optimisation settings.
    epochs=15,
    learning_rate=1e-4,
    # Architecture settings: `features_real` and `width_scale` are only read
    # when model == 'real'; `features_complex` only when model == 'complex'.
    features_real=[32, 64, 128, 256, 512],
    features_complex=[32, 64, 128, 256, 512],
    width_scale=1.416,
    # Passed to torch.manual_seed before the model is constructed.
    seed=1,
    # Print a progress line every this many training steps.
    log_every=10,
)
CONFIG


{'model': 'complex',
 'train_folder': '/Users/giodegeronimo/Desktop/ECE570/ece570-tinyreproductions/data/singlecoil_train',
 'val_folder': '/Users/giodegeronimo/Desktop/ECE570/ece570-tinyreproductions/data/singlecoil_val',
 'mask': {'accel': 4, 'acs': 24},
 'train_subset': 1024,
 'val_subset': 256,
 'batch_size': 4,
 'num_workers': 2,
 'epochs': 15,
 'learning_rate': 0.0001,
 'features_real': [32, 64, 128, 256, 512],
 'features_complex': [32, 64, 128, 256, 512],
 'width_scale': 1.416,
 'seed': 1,
 'log_every': 10}

In [3]:

# One undersampling mask generator, shared by the training and validation sets.
masker = EquispacedMasker(
    accel=CONFIG['mask']['accel'],
    acs=CONFIG['mask']['acs'],
)

# Full datasets, each reading from its configured folder.
train_full = SingleCoilDataset(CONFIG['train_folder'], mask_func=masker)
val_full = SingleCoilDataset(CONFIG['val_folder'], mask_func=masker)

# Keep only the leading items of each split, capped at the dataset length so a
# small dataset never produces out-of-range indices.
train_indices = [*range(min(CONFIG['train_subset'], len(train_full)))]
val_indices = [*range(min(CONFIG['val_subset'], len(val_full)))]
train_set = Subset(train_full, train_indices)
val_set = Subset(val_full, val_indices)

# Training batches are shuffled; validation batches keep dataset order.
train_loader = DataLoader(
    train_set,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=CONFIG['num_workers'],
)
val_loader = DataLoader(
    val_set,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=CONFIG['num_workers'],
)

# Display the resulting split sizes as the cell output.
len(train_set), len(val_set)


(1024, 256)

In [4]:

# Seed torch's global RNG before the model is constructed.
torch.manual_seed(CONFIG['seed'])

# Device preference order: Apple MPS first, then CUDA, then plain CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Any value other than 'complex' selects the real-valued U-Net.
if CONFIG['model'] != 'complex':
    model = RealUnet(
        in_channels=1,
        out_channels=1,
        features=CONFIG['features_real'],
        width_scale=CONFIG['width_scale'],
    )
else:
    model = ComplexUnet(
        in_channels=1,
        out_channels=1,
        features=CONFIG['features_complex'],
    )
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['learning_rate'])


def loss_fn(pred, target):
    """Return the mean absolute difference between `pred` and `target`."""
    return (pred - target).abs().mean()


In [6]:
import torch.nn.utils as nn_utils



def _save_debug_snapshot(target, pred, out_path, idx=0):
    """Write a side-by-side magnitude image of one target/prediction pair.

    Args:
        target: Batch of ground-truth tensors; item `idx` is plotted.
        pred: Batch of model outputs; item `idx` is plotted.
        out_path: Destination image path; missing parent directories are created.
        idx: Position within the batch of the sample to plot.
    """
    # Imported lazily so matplotlib is only required when a snapshot is written.
    import matplotlib.pyplot as plt

    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    panels = (("Target", target[idx]), ("Prediction", pred[idx]))
    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    try:
        for ax, (title, tensor) in zip(axes, panels):
            # Detach first so plotting never extends the autograd graph.
            image = tensor.detach().abs().squeeze().cpu().numpy()
            ax.imshow(image, cmap="gray")
            ax.set_title(title)
            ax.axis("off")
        fig.tight_layout()
        fig.savefig(out_path, dpi=200)
    finally:
        # Always release the figure, even if saving fails, so a long run does
        # not accumulate open figures.
        plt.close(fig)


def train_epoch_debug(model, dataloader, optimizer, loss_fn, device, log_every=None,
                      save_every=20, save_dir="debug_outputs", epoch=None):
    """Train `model` for one pass over `dataloader`, printing every step's loss.

    Besides the optimisation itself, this periodically saves a target vs.
    prediction image of the first sample in the batch for visual inspection.

    Args:
        model: Module to optimise; switched to training mode.
        dataloader: Iterable of `(masked, target)` batches that supports `len()`.
        optimizer: Optimiser stepped once per batch.
        loss_fn: Callable `loss_fn(pred, target)` returning a scalar tensor.
        device: Device the batches are moved to before the forward pass.
        log_every: If not None, print a `Step i/N` summary line every
            `log_every` steps and on the final step.
        save_every: Save a snapshot image every `save_every` steps. `None` or
            `0` disables snapshots (and the matplotlib dependency).
        save_dir: Directory receiving the snapshot images; created on first use.
        epoch: Optional epoch number. When given it is embedded in the snapshot
            filename, so images from different epochs no longer overwrite each
            other. When omitted, the filename is `stepNNNN.png`.

    Returns:
        Sum of the per-step losses divided by `len(dataloader)` (0.0 for an
        empty dataloader).
    """
    model.train()
    num_steps = len(dataloader)
    snapshot_dir = Path(save_dir)
    snapshot_prefix = "" if epoch is None else f"epoch{epoch:03d}_"

    total_loss = 0.0
    for step, (masked, target) in enumerate(dataloader, start=1):
        masked = masked.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        pred = model(masked)

        if save_every and step % save_every == 0:
            _save_debug_snapshot(
                target, pred, snapshot_dir / f"{snapshot_prefix}step{step:04d}.png"
            )

        loss = loss_fn(pred, target)
        # Read the scalar once and reuse it for printing, logging and the
        # running total instead of calling `.item()` repeatedly.
        loss_value = loss.item()
        print(loss_value)

        loss.backward()
        optimizer.step()
        total_loss += loss_value

        if log_every is not None and (step % log_every == 0 or step == num_steps):
            print(f"Step {step}/{num_steps} loss {loss_value:.4e}")

    # Guard against division by zero when the dataloader is empty.
    return total_loss / max(num_steps, 1)


In [7]:
# Alternate one debug training pass with one validation pass per epoch and
# print both average losses on a single summary line.
for epoch in range(1, CONFIG['epochs'] + 1):
    train_loss = train_epoch_debug(
        model=model,
        dataloader=train_loader,
        optimizer=optimizer,
        loss_fn=loss_fn,
        device=device,
        log_every=CONFIG['log_every'],
    )
    val_loss = test_loop(model, val_loader, loss_fn, device)
    print(f"Epoch {epoch}/{CONFIG['epochs']} | train {train_loss:.4e} | val {val_loss:.4e}")


  eigvals, eigvecs = torch.linalg.eigh(V)  # (C, 2), (C, 2, 2)


0.5534026026725769
0.5131915211677551
0.5409350991249084
0.4946202337741852
0.44553518295288086
0.43648549914360046
0.43758314847946167
0.4134536683559418
0.41041579842567444
0.37925222516059875
Step 10/256 loss 3.7925e-01
0.379495769739151
0.3685612380504608
0.3839791715145111
0.3524315059185028
0.3307341933250427
0.34352073073387146
0.35040533542633057
0.3063223958015442
0.33852720260620117
0.31633952260017395
Step 20/256 loss 3.1634e-01
0.3091432750225067
0.309565931558609
0.30736008286476135
0.2835833430290222
0.2834860682487488
0.2694880962371826
0.301587849855423
0.26829975843429565
0.28632667660713196
0.25363367795944214
Step 30/256 loss 2.5363e-01
0.25624194741249084
0.264178991317749
0.2338613122701645
0.29253244400024414
0.22331306338310242
0.24628959596157074
0.25774380564689636
0.2107411026954651
0.23133087158203125
0.23793725669384003
Step 40/256 loss 2.3794e-01
0.24886386096477509
0.24104513227939606
0.23012231290340424
0.2552434206008911
0.258008748292923
0.2502957880496

val:   0%|          | 0/64 [00:00<?, ?it/s]

Epoch 1/15 | train 3.5774e-01 | val 5.8970e-01
0.5723958611488342
0.4838281273841858
0.5285736918449402
0.47784942388534546
0.5463464856147766
0.5428199172019958
0.5108019113540649
0.5383570194244385
0.5637804865837097
0.5937706828117371
Step 10/256 loss 5.9377e-01
0.5514652729034424
0.5421373844146729
0.5517671704292297
0.5477685332298279
0.5033969879150391
0.4834699034690857
0.5044769048690796
0.5057443976402283
0.46205398440361023
0.47638121247291565
Step 20/256 loss 4.7638e-01
0.5393104553222656
0.5735322833061218
0.4287916123867035
0.5033366084098816
0.469308465719223
0.4665353298187256
0.47525084018707275
0.46237269043922424
0.4551950991153717
0.5216749310493469
Step 30/256 loss 5.2167e-01
0.4087998867034912
0.47094568610191345
0.4378039538860321
0.451725035905838
0.4483736455440521
0.4833648204803467
0.5121293663978577
0.4464607238769531
0.4628717005252838
0.43592140078544617
Step 40/256 loss 4.3592e-01
0.4326455593109131
0.524563193321228
0.498033732175827
0.44355064630508423
0

val:   0%|          | 0/64 [00:00<?, ?it/s]

Epoch 2/15 | train 4.2279e-01 | val 3.2773e-01
0.33980730175971985
0.3233066201210022
0.37494105100631714
0.3467027246952057
0.3858273923397064
0.3278050124645233
0.3410329222679138
0.3648868203163147
0.3474820554256439
0.35033321380615234
Step 10/256 loss 3.5033e-01
0.3709387183189392
0.40857401490211487
0.3980324864387512
0.37423452734947205
0.3217097520828247
0.36231693625450134
0.29988712072372437
0.29867926239967346
0.31375664472579956
0.32129424810409546
Step 20/256 loss 3.2129e-01
0.30631133913993835
0.3427363932132721
0.30730316042900085
0.32660186290740967
0.2981935739517212
0.3120017945766449
0.3180062174797058
0.3048381209373474
0.3234547972679138
0.3025783896446228
Step 30/256 loss 3.0258e-01
0.3433068096637726
0.33572763204574585
0.3197738230228424
0.32156506180763245
0.330477237701416
0.3255014717578888
0.30031558871269226
0.2952449321746826
0.30420583486557007
0.30148831009864807
Step 40/256 loss 3.0149e-01
0.2792803645133972
0.31876087188720703
0.27887845039367676
0.281

: 

In [None]:
import torch
from models.cx_unet import ComplexConv2d

# Isolate the bias handling of ComplexConv2d: with all kernel weights zeroed,
# the layer's output can only come from how it combines its two biases.
conv = ComplexConv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1, bias=True).eval()
with torch.no_grad():
    # Real branch: no kernel contribution, bias intended to be +1.
    conv.conv_real.weight.fill_(0.0)
    conv.conv_real.bias.fill_(1.0)
    # Imaginary branch: no kernel contribution, bias intended to be +0.25.
    conv.conv_imag.weight.fill_(0.0)
    conv.conv_imag.bias.fill_(0.25)

# An all-zero complex input removes any dependence on the data itself.
x = torch.zeros((1, 1, 8, 8), dtype=torch.complex64)
y = conv(x)

# Report the distinct values in each component; compare against 1.0 and 0.25.
print("Output real part (should equal 1.0 if biasing were correct):", torch.unique(y.real))
print("Output imag part (should equal 0.25 if biasing were correct):", torch.unique(y.imag))
