# In [None]:
import torch, random, os, time
import numpy as np
import torch.nn as nn
import learn2learn as l2l
import torch.nn.functional as F
import torchio as tio
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# In [None]:
class SimpleUNet(nn.Module):
    """Small fully convolutional network: three 3x3 convs and a 1x1 head.

    Despite the name there is no down/up-sampling and there are no skip
    connections; every 3x3 convolution uses ``padding=1`` so the spatial
    size of the input is preserved all the way to the output.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the final 1x1 conv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Shared settings for the size-preserving 3x3 convolutions.
        conv3x3 = dict(kernel_size=3, padding=1)
        self.enc1 = nn.Conv2d(in_channels, 32, **conv3x3)
        self.enc2 = nn.Conv2d(32, 64, **conv3x3)
        self.dec1 = nn.Conv2d(64, 32, **conv3x3)
        # Linear 1x1 projection; no activation is applied after it.
        self.out = nn.Conv2d(32, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply conv+ReLU three times, then the linear 1x1 output conv."""
        hidden = x
        for conv in (self.enc1, self.enc2, self.dec1):
            hidden = F.relu(conv(hidden))
        return self.out(hidden)

def signal_model(I0, T2, TE):
    """Evaluate the mono-exponential decay ``I0 * exp(-TE / T2)``.

    Args:
        I0: Signal amplitude at ``TE = 0``; must broadcast against
            ``(1, n_echoes, 1, 1)``.
        T2: Decay constant, broadcast the same way as ``I0``.
        TE: Tensor of echo times; it is reshaped to ``(1, n_echoes, 1, 1)``
            so the echoes end up along dimension 1 of the result.

    Returns:
        The decayed signal, one channel per echo time.
    """
    echo_times = TE.view(1, -1, 1, 1)
    # The epsilon only avoids a division by exactly zero; it does not
    # protect against negative T2 values.
    decay_exponent = echo_times / (T2 + 1e-6)
    return I0 * torch.exp(-decay_exponent)

def relax_loss(pred_maps, kspace_ref, mask, TE):
    """Data-consistency loss between predicted and reference k-space.

    The predicted parameter maps are pushed through ``signal_model``,
    transformed with an orthonormal 2-D FFT, multiplied by ``mask`` and
    compared with ``kspace_ref``.

    Args:
        pred_maps: Network output; channel 0 is read as I0 and channel 1
            as T2.
        kspace_ref: Reference k-space whose ``.real`` and ``.imag`` parts
            are compared (presumably already under-sampled with the same
            mask; confirm against the data pipeline).
        mask: Sampling mask multiplied onto the predicted k-space.
        TE: Echo times forwarded to ``signal_model``.

    Returns:
        Sum of the mean-squared errors of the real and imaginary parts.
    """
    # Slices (not integer indices) keep the channel dimension.
    I0 = pred_maps[:, 0:1]
    T2 = pred_maps[:, 1:2]

    predicted_k = torch.fft.fft2(signal_model(I0, T2, TE), norm='ortho') * mask

    # mse_loss is applied to real and imaginary parts separately.
    real_term = F.mse_loss(predicted_k.real, kspace_ref.real)
    imag_term = F.mse_loss(predicted_k.imag, kspace_ref.imag)
    return real_term + imag_term

def meta_train_step(model, maml, tasks, TE):
    """Return the mean post-adaptation query loss over a batch of tasks.

    For every task a fresh clone of ``maml`` is adapted once on the support
    set and then evaluated on the query set. Nothing is back-propagated or
    stepped here; the caller is expected to do that with the returned loss.

    Args:
        model: Not used by this function; kept so existing call sites keep
            working.
        maml: Meta-learner exposing ``clone()``; the clone must be callable
            and provide ``adapt(loss)`` (presumably ``l2l.algorithms.MAML``;
            confirm against the caller).
        tasks: Iterable of ``(support_x, support_k, query_x, query_k, mask)``
            tuples. Any iterable is accepted, including generators.
        TE: Echo times forwarded unchanged to ``relax_loss``.

    Returns:
        Scalar tensor holding the query loss averaged over all tasks.

    Raises:
        ValueError: If ``tasks`` yields no task.
    """
    query_losses = []
    for support_x, support_k, query_x, query_k, mask in tasks:
        # Adapt a per-task copy so the meta-parameters stay untouched.
        learner = maml.clone()
        support_loss = relax_loss(learner(support_x), support_k, mask, TE)
        learner.adapt(support_loss)

        query_losses.append(relax_loss(learner(query_x), query_k, mask, TE))

    if not query_losses:
        raise ValueError("meta_train_step requires at least one task")

    # Stack-and-mean avoids in-place updates on tensors that belong to the
    # autograd graph and does not need len(tasks).
    return torch.stack(query_losses).mean()