In [None]:
import os
import sys
import numpy as np
import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from scipy.ndimage import zoom
import matplotlib.pyplot as plt

In [None]:
# Prefer the second GPU (cuda:1) when there is one, but fall back to cuda:0 on
# single-GPU machines instead of selecting a device index that does not exist.
if torch.cuda.is_available():
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    device = "cpu"
print(f"Using device: {device}")

In [None]:
# Quick look at one slice file: list the arrays it contains and show the axial slice.
# Set `path` to a specific .npz file, or leave it empty to use the first file
# found under data/combined_slices.
path = ""
if not path:
    candidates = sorted(glob.glob("data/combined_slices/*.npz"))
    if not candidates:
        raise FileNotFoundError("no .npz slice files found in data/combined_slices")
    path = candidates[0]

# The axial slices are plain numeric arrays, so allow_pickle is not needed (and
# would allow arbitrary code execution for untrusted files). The context manager
# closes the underlying file handle once the slice has been read into memory.
with np.load(path) as sample:
    print(list(sample.keys()))
    sample_slice = sample["axial_slice"]

plt.imshow(sample_slice.T, cmap="gray");

In [None]:
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
# Guard so re-running this cell does not keep appending the same entry.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)


# Import the model (presumably torch_interlacer is importable from the parent directory)
from torch_interlacer.models import get_interlacer_residual_model

In [None]:
class MRIDataset(Dataset):
    """Dataset of 2-D MRI slices paired with retrospectively undersampled versions.

    For every slice the (transposed) image is Fourier transformed and fftshifted.
    It is then multiplied by a fixed row mask that keeps or drops whole k-space
    rows. Each item is ``(input, target)``, two float tensors of shape
    ``(2, H, W)`` whose channels are real and imaginary parts.

    Args:
        images: sequence of 2-D arrays. After transposition each must already be
            ``(target_size, target_size)``, because the mask is applied
            element-wise and no resizing happens here.
        us_frac: fraction of k-space rows to discard (0.75 keeps ~25% of rows).
            The central band is always kept, so fewer rows are discarded when
            the band alone exceeds the requested number.
        input_domain: 'IMAGE' for the aliased image; any other value yields the
            undersampled k-space.
        output_domain: 'IMAGE' for the clean image; any other value yields the
            full k-space.
        transform: stored but not applied by ``__getitem__``.
        target_size: side length of the square undersampling mask.
    """

    def __init__(self, images, us_frac=0.75, input_domain='IMAGE', output_domain='IMAGE', transform=None, target_size=256):
        self.images = images
        self.us_frac = us_frac
        self.input_domain = input_domain
        self.output_domain = output_domain
        self.transform = transform
        self.target_size = target_size

        # One mask is drawn per dataset and reused for every item.
        self.mask = self._create_undersampling_mask(target_size, us_frac)

    def _create_undersampling_mask(self, size, us_frac, band_size=40):
        """Build a boolean ``(size, size)`` row mask: a central band plus random outer rows.

        The central ``band_size + 1`` rows are always kept (clamped to the array).
        Extra rows are then drawn uniformly without replacement from outside the
        band, using the global ``np.random`` state. Drawing stops once
        ``int((1 - us_frac) * size)`` rows are kept in total, or when no outer
        rows remain.
        """
        mask = np.zeros((size, size), dtype=bool)
        center = size // 2
        # Clamp the band so small sizes do not yield negative (wrap-around) slice starts.
        lo = max(center - band_size // 2, 0)
        hi = min(center + band_size // 2, size - 1)
        mask[lo:hi + 1, :] = True

        target_lines = int((1 - us_frac) * size)
        extra_lines = max(0, target_lines - (hi - lo + 1))
        if extra_lines > 0:
            outer_rows = np.concatenate([np.arange(0, lo), np.arange(hi + 1, size)])
            # All missing rows are drawn at once; previously only half were drawn.
            n_pick = min(extra_lines, outer_rows.size)
            if n_pick > 0:
                chosen = np.random.choice(outer_rows, n_pick, replace=False)
                mask[chosen, :] = True

        return mask

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(input, target)`` float tensors of shape ``(2, H, W)`` for slice ``idx``."""
        img = self.images[idx].T
        # Simulate acquisition: centred k-space of the fully sampled slice.
        kspace = np.fft.fftshift(np.fft.fft2(img.astype(np.complex64)))
        kspace_undersampled = kspace * self.mask

        if self.input_domain == 'IMAGE':
            # A random row mask is generally not Hermitian-symmetric, so the aliased
            # image is complex; keep both parts rather than dropping the imaginary one.
            img_undersampled = np.fft.ifft2(np.fft.ifftshift(kspace_undersampled))
            input_data = np.stack([img_undersampled.real, img_undersampled.imag], axis=0)
        else:  # FREQ
            input_data = np.stack([kspace_undersampled.real, kspace_undersampled.imag], axis=0)

        if self.output_domain == 'IMAGE':
            # The clean image is real, so its imaginary channel is all zeros.
            output_data = np.stack([img, np.zeros_like(img)], axis=0)
        else:  # FREQ
            output_data = np.stack([kspace.real, kspace.imag], axis=0)

        return torch.from_numpy(input_data).float(), torch.from_numpy(output_data).float()

In [None]:
def _center_fit_axis(x, axis, size):
    """Zero-pad or crop ``x`` along one axis to ``size``, keeping the content centred."""
    diff = size - x.shape[axis]
    if diff > 0:
        # Odd differences put the extra padding on the bottom/right edge.
        before = diff // 2
        pad_width = [(0, 0)] * x.ndim
        pad_width[axis] = (before, diff - before)
        return np.pad(x, pad_width=pad_width)
    if diff < 0:
        # Odd differences remove the extra row/column from the bottom/right edge.
        start = (-diff) // 2
        index = [slice(None)] * x.ndim
        index[axis] = slice(start, start + size)
        return x[tuple(index)]
    return x


def mri_safe_resize(dataset, target=(256, 256)):
    """Centre-pad and/or centre-crop every 2-D array in ``dataset`` to ``target``.

    Each axis is handled independently, so a slice can be padded along one axis
    and cropped along the other. Padding uses zeros (``np.pad`` constant mode).
    Explicit slice bounds avoid the empty result that ``x[0:-0]`` produced when
    one axis already matched.

    Args:
        dataset: iterable of 2-D arrays.
        target: ``(rows, cols)`` output shape.

    Returns:
        A new list of arrays of shape ``target``. Arrays already at ``target``
        are passed through unchanged; crops are views of the input.
    """
    target = tuple(target)
    out_set = []
    for x in dataset:
        if x.shape != target:
            for axis, size in enumerate(target):
                x = _center_fit_axis(x, axis, size)
        out_set.append(x)
    return out_set

def normalize(dataset):
    """Shift and scale a collection of arrays by shared global statistics.

    Every array ``x`` becomes ``(x - mean) / peak``. Here ``mean`` is the mean
    over all elements of all arrays, and ``peak`` is the largest element in any
    array. When ``peak`` is 0 the arrays are only centred, to avoid a division
    by zero.

    Args:
        dataset: sequence of numeric NumPy arrays (shapes may differ).

    Returns:
        A new list of normalised arrays; the input sequence is not modified.
        float32 inputs stay float32. An empty input returns ``[]``.
    """
    arrays = list(dataset)
    if not arrays:
        return []
    # Element-weighted global mean, accumulated in float64. When all arrays share
    # a shape this equals the mean of the per-array means.
    total = sum(float(x.sum(dtype=np.float64)) for x in arrays)
    count = sum(x.size for x in arrays)
    mean = total / count
    peak = max(float(x.max()) for x in arrays)
    scale = peak if peak != 0 else 1.0
    # Python-float scalars keep the arrays' own dtype (e.g. float32).
    return [(x - mean) / scale for x in arrays]
        

In [None]:
# load data
TEST_SPLIT = 0.2
SEED = 0  # fixes the split and the later np.random / torch draws (masks, init, shuffling)
DATA_GLOB = "data/combined_slices/*.npz"

np.random.seed(SEED)
torch.manual_seed(SEED)

# glob order depends on the filesystem; sort so the seeded split is stable.
slice_files = sorted(glob.glob(DATA_GLOB))
if not slice_files:
    raise FileNotFoundError(f"no slice files matched {DATA_GLOB!r}")
print(f"found {len(slice_files)} slice files, e.g. {slice_files[0]}")
with np.load(slice_files[0]) as first_file:
    print(f"keys: {list(first_file.keys())}")

all_data = []
for fname in slice_files:
    with np.load(fname) as npz:
        all_data.append(npz["axial_slice"])

print(f"size of dataset: {len(all_data)}")

n = len(all_data)
indices = np.random.permutation(n)
n_test = int(n * TEST_SPLIT)
test_idxs = indices[:n_test]
train_idxs = indices[n_test:]

test_data = [all_data[i].astype(np.float32) for i in test_idxs]
train_data = [all_data[i].astype(np.float32) for i in train_idxs]

test_data = mri_safe_resize(test_data)
train_data = mri_safe_resize(train_data)

# Normalise BOTH splits with statistics from the training split only, so the
# held-out data is on exactly the scale the model is trained on. The form is
# (x - mean) / max. Every slice has the same shape after mri_safe_resize, so the
# mean of per-slice means equals the global mean.
train_mean = float(np.mean([x.mean(dtype=np.float64) for x in train_data]))
train_max = float(max(x.max() for x in train_data))
train_scale = train_max if train_max != 0 else 1.0
train_data = [(x - train_mean) / train_scale for x in train_data]
test_data = [(x - train_mean) / train_scale for x in test_data]


print(f"train slices: {len(train_data)}, held-out slices: {len(test_data)}, slice shape: {train_data[0].shape}")

In [None]:
# model and train configs
batch_size = 16
us_frac = 0.4
input_domain = 'FREQ'
output_domain = 'FREQ'
lr = 0.001
num_epochs = 20

# Train on the training split and validate on the held-out split
# (the two were previously swapped: training used the 20% split).
train_images, val_images = train_data, test_data
train_dataset = MRIDataset(train_images, us_frac, input_domain, output_domain, transform=None)
val_dataset = MRIDataset(val_images, us_frac, input_domain, output_domain, transform=None)
# Validate under the same sampling pattern the model is trained on.
val_dataset.mask = train_dataset.mask
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)


In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 7))
input_sample, output_sample = train_dataset[0]

print(input_sample.shape)
print(output_sample.shape)

# Assumes input_domain == 'FREQ': channels 0/1 are the real and imaginary parts
# of the fftshifted, undersampled k-space.
kspace_in = torch.complex(input_sample[0], input_sample[1])
# Raw k-space spans many orders of magnitude, so the real part alone is unreadable.
# Log-magnitude makes the sampled rows visible.
axes[0].imshow(torch.log1p(kspace_in.abs()).numpy(), cmap="gray")
axes[0].set_title("Corrupted Frequency Space - Axial Slice (log magnitude)")
axes[0].axis("off")

# Zero-filled reconstruction: undo the shift and inverse-FFT back to image space.
corrupt = torch.fft.ifft2(torch.fft.ifftshift(kspace_in))
axes[1].imshow(corrupt.real, cmap="gray")
axes[1].set_title("Corrupted Image")
axes[1].axis("off");


In [None]:
# Build the residual Interlacer network. input_size (channels, H, W) matches the
# 2-channel 256x256 samples produced by MRIDataset with its default target_size.
interlacer_kwargs = {
    "input_size": (2, 256, 256),
    "nonlinearity": '3-piece',
    "kernel_size": 9,
    "num_features": 32,
    "num_convs": 1,
    "num_layers": 10,
}
model = get_interlacer_residual_model(**interlacer_kwargs).to(device)

param_count = sum(param.numel() for param in model.parameters())
print(f"Model created and moved to {device}")
print(f"Model parameters: {param_count:,}")


In [None]:
# Optimise every model weight with Adam at the configured learning rate and
# score predictions with mean absolute error.
optimizer = optim.Adam(params=model.parameters(), lr=lr)
criterion = nn.L1Loss()

In [None]:
def train_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: network to train (switched to train mode).
        dataloader: yields ``(inputs, targets)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss function; assumed to use mean reduction (the batch loss
            is re-weighted by batch size) — TODO confirm if it is swapped.
        device: device the batches are moved to.

    Returns:
        Per-sample mean loss over the epoch. Each batch loss is weighted by its
        batch size, so a short final batch does not skew the average.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.train()
    total_loss = 0.0
    n_samples = 0

    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        n_batch = inputs.size(0)
        total_loss += loss.item() * n_batch
        n_samples += n_batch

        if batch_idx % 10 == 0:
            print(f'Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.6f}')

    if n_samples == 0:
        raise ValueError("dataloader yielded no samples")
    return total_loss / n_samples

def validate_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating weights.

    Args:
        model: network to evaluate (switched to eval mode).
        dataloader: yields ``(inputs, targets)`` batches.
        criterion: loss function; assumed to use mean reduction (the batch loss
            is re-weighted by batch size) — TODO confirm if it is swapped.
        device: device the batches are moved to.

    Returns:
        Per-sample mean loss. Each batch loss is weighted by its batch size, so
        the result does not depend on how the data happens to be batched.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            n_batch = inputs.size(0)
            total_loss += criterion(outputs, targets).item() * n_batch
            n_samples += n_batch

    if n_samples == 0:
        raise ValueError("dataloader yielded no samples")
    return total_loss / n_samples

In [None]:
# Training loop: one train pass and one validation pass per epoch; the weights
# are checkpointed to best_model.pth whenever validation loss strictly improves.
best_val_loss = float('inf')
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss = validate_epoch(model, val_loader, criterion, device)
    print(f"Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}")

    # Written as `not (<)` so a NaN loss is never treated as an improvement.
    if not (val_loss < best_val_loss):
        continue
    best_val_loss = val_loss
    torch.save(model.state_dict(), 'best_model.pth')
    print(f"New best model saved! Val Loss: {val_loss:.6f}")

print(f"\nTraining completed!")
print(f"Best validation loss: {best_val_loss:.6f}")


In [None]:
def get_psnr(img1, img2, max_pixel=1.0):
    """Peak signal-to-noise ratio in dB between two equally shaped arrays.

    ``PSNR = 20 * log10(max_pixel / sqrt(MSE))``.

    Args:
        img1, img2: array-likes of the same shape (anything ``np.asarray``
            accepts, e.g. NumPy arrays or CPU tensors).
        max_pixel: peak value of the signal's dynamic range. The default of 1.0
            preserves the previous behaviour.

    Returns:
        PSNR as a float; ``inf`` when the inputs are identical.

    Raises:
        ValueError: if the shapes differ (instead of silently broadcasting).
    """
    a = np.asarray(img1, dtype=np.float64)
    b = np.asarray(img2, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    mse = float(np.mean((a - b) ** 2))
    if mse == 0:
        return float('inf')
    return float(20 * np.log10(max_pixel / np.sqrt(mse)))


In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt

# Force white backgrounds on both the figure canvas and the plotting area
# (presumably so figures stay readable under a dark notebook theme).
mpl.rcParams.update({
    'figure.facecolor': 'white',  # figure bg
    'axes.facecolor': 'white',    # axes bg
})


In [None]:
import matplotlib.pyplot as plt

EVAL_IDX = 131  # validation slice to visualise
# Clamp so a smaller validation split does not raise IndexError.
sample_idx = min(EVAL_IDX, len(val_dataset) - 1)

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 10))
input_sample, output_sample = val_dataset[sample_idx]
print(input_sample.shape)
print(output_sample.shape)

# Assumes FREQ domains: channels are real/imag parts of fftshifted k-space,
# so undo the shift and inverse-FFT to get back to image space.
gt_tensor = output_sample.cpu().detach().squeeze()
gt_img = torch.fft.ifft2(torch.fft.ifftshift(torch.complex(gt_tensor[0], gt_tensor[1]))).real
axes[0].imshow(gt_img.numpy(), cmap="gray")
axes[0].set_title("Validation Set Axial Slice - GT")
axes[0].axis("off")

# Zero-filled reconstruction of THIS sample; its PSNR is computed only after it
# has been rebuilt here, never from a `corrupt` left over by another cell.
corrupt = torch.fft.ifft2(torch.fft.ifftshift(torch.complex(input_sample[0], input_sample[1]))).real
psnr_corrupt = get_psnr(gt_img.numpy(), corrupt.numpy())
axes[1].imshow(corrupt.numpy(), cmap="gray")
axes[1].set_title(f"Corrupted Image - PSNR: {psnr_corrupt:.4f}")
axes[1].axis("off")

# NOTE(review): this evaluates the in-memory (last-epoch) weights, not best_model.pth.
model.eval()
with torch.no_grad():
    recon = model(input_sample.unsqueeze(0).to(device)).cpu().squeeze()
print(recon.shape)
recon_img = torch.fft.ifft2(torch.fft.ifftshift(torch.complex(recon[0], recon[1]))).real
psnr_recon = get_psnr(gt_img.numpy(), recon_img.numpy())
axes[2].imshow(recon_img.numpy(), cmap="gray")
axes[2].set_title(f"Reconstructed Image - PSNR: {psnr_recon:.4f}")
axes[2].axis("off")

print(psnr_corrupt)
print(psnr_recon)

In [None]:
# Signed error of the zero-filled image; symmetric colour limits put zero error
# at the diverging colormap's neutral midpoint.
err_corrupt = (gt_img - corrupt).numpy()
lim_corrupt = float(np.abs(err_corrupt).max()) or 1.0
fig, ax = plt.subplots()
im = ax.imshow(err_corrupt, cmap="RdBu_r", vmin=-lim_corrupt, vmax=lim_corrupt)
fig.colorbar(im, ax=ax)
ax.set_title("GT - corrupted")
ax.axis("off");

In [None]:
# Signed reconstruction error; symmetric colour limits put zero error at the
# diverging colormap's neutral midpoint.
err_recon = (gt_img - recon_img.real).numpy()
lim_recon = float(np.abs(err_recon).max()) or 1.0
fig, ax = plt.subplots()
im = ax.imshow(err_recon, cmap="RdBu_r", vmin=-lim_recon, vmax=lim_recon)
fig.colorbar(im, ax=ax)
ax.set_title("GT - reconstruction")
ax.axis("off");

In [None]:
# What the network changed relative to the zero-filled input; symmetric colour
# limits put "no change" at the diverging colormap's neutral midpoint.
delta_recon = (corrupt - recon_img.real).numpy()
lim_delta = float(np.abs(delta_recon).max()) or 1.0
fig, ax = plt.subplots()
im = ax.imshow(delta_recon, cmap="RdBu_r", vmin=-lim_delta, vmax=lim_delta)
fig.colorbar(im, ax=ax)
ax.set_title("corrupted - reconstruction")
ax.axis("off");