In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import torchvision.transforms.functional as F

import numpy as np
import os
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import re
from typing import NamedTuple

from utils.data import Data
from utils.manager import set_seed, set_cuda, fetch_paths, set_logger, set_device, RunManager
from utils.fourier import ifft2c as ift
from utils.fourier import fft2c as ft
from utils.math import complex_abs
from utils.mask import apply_random_mask
from utils.transform import Transform

from models.miccan import MICCAN

In [2]:
# ARGUMENTS
class Arguments:
    """Run configuration for this notebook.

    Every setting is a keyword argument whose default is the value that was
    previously hard-coded, so ``Arguments()`` yields the same configuration as
    before while a single setting can be overridden without editing the class,
    e.g. ``Arguments(tnv=120)``.

    Args:
        acc: Values forwarded as ``accelerations`` to ``Transform``.
        tnv: Value forwarded as ``nv`` to the training ``Data`` set
            (the original cell noted 120 as an alternative value).
        mtype: Value forwarded as ``mask_type`` to ``Transform``.
        dset: Dataset key handed to ``fetch_paths``.
        seq_types: Values forwarded as ``seq_types`` to ``Data``.

    Further attributes (such as ``bs``) may be attached to the instance later.
    """

    def __init__(
        self,
        acc=(1, 2, 3, 4),
        tnv=20,
        mtype='random',
        dset='fastmribrain',
        seq_types=("AXT1", "AXT2", "AXFLAIR"),
    ):
        # Store list copies: attributes stay lists exactly as before, and no
        # two instances (or the tuple defaults) share a mutable object.
        self.acc = list(acc)
        self.tnv = tnv
        self.mtype = mtype
        self.dset = dset
        self.seq_types = list(seq_types)


args = Arguments()

In [3]:
# Seed the run via the project's set_seed helper, called with its default arguments.
# NOTE(review): which RNGs it seeds (and with what value) is defined in
# utils.manager and is not visible from this notebook — confirm there.
set_seed()

In [4]:
# Resolve the paths for the configured dataset key. fetch_paths returns two
# values; only the first is kept (used below as the dataset root), the second
# is discarded.
data_path, _ = fetch_paths(args.dset)

In [5]:
# Training-mode transform configured with the mask type and accelerations from args.
train_transform = Transform(
    train=True,
    mask_type=args.mtype,
    accelerations=args.acc,
)

# Training split restricted to the configured sequence types; `nv` comes from args.tnv.
train_dataset = Data(
    root=data_path,
    train=True,
    seq_types=args.seq_types,
    transform=train_transform,
    nv=args.tnv,
)

# Summarise the split: total volumes / slices, then the per-sequence table
# with its last entry left out.
print(f'Training set: No. of volumes: {train_dataset.num_volumes} | No. of slices: {len(train_dataset)}')
print(f'{train_dataset.data_per_seq[:-1]}')

Training set: No. of volumes: 120 | No. of slices: 1874
AXT1    : 40 | 602
AXT2    : 40 | 636
AXFLAIR : 40 | 636


In [6]:
# The batch size is set to the dataset length, so the loader yields the whole
# training set as one (shuffled) batch.
args.bs = len(train_dataset)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=args.bs,
    num_workers=0,
    shuffle=True,
    pin_memory=True,
)

In [7]:
# Take the first batch from the loader. Because the batch size was set to the
# dataset length, this one batch contains every training slice.
batch = next(iter(train_loader))
# Display the shape of the batch's `target` tensor as the cell output.
batch.target.shape

torch.Size([1874, 1, 320, 320])

In [8]:
# Turn the batch into a 2-D data matrix with one row per sample: keep dim 0 and
# flatten everything after it. Compared with `.squeeze().view(args.bs, -1)`:
#   * the row count comes from the tensor itself rather than from args.bs, so a
#     batch whose size differs from args.bs can no longer be silently reshaped
#     into the wrong number of rows;
#   * no bare `squeeze()`, which would also drop any other size-1 dimension;
#   * `flatten` also accepts non-contiguous tensors, where `.view` raises.
image = batch.target.flatten(start_dim=1)
image.shape

torch.Size([1874, 102400])

In [9]:
# Randomised low-rank PCA of the (samples x features) matrix `image`.
# q=None leaves the number of components to torch's default; center=True makes
# pca_lowrank centre the data internally; niter=3 subspace iterations.
U, S, V = torch.pca_lowrank(image, q=None, center=True, niter=3)

# Project every sample onto the first two columns of V.
# NOTE(review): the decomposition is computed on centred data, but the
# un-centred `image` is projected here — confirm that is intended.
# NOTE(review): the saved output of this cell reports 3 columns, which does
# not match the `:2` slice — the output is probably from an earlier run.
pcs = image @ V[:, :2]
pcs.shape

torch.Size([1874, 3])

In [None]:
for i in range(pcs):
    