In [2]:
import warnings

from einops import rearrange
import matplotlib.pyplot as plt
import numpy as np
import torch
from pytorch_memlab import MemReporter

from invivo_data import load_data
from linop import SubspaceLinopFactory

# Use the GPU when CUDA is available, otherwise fall back to the CPU; this
# device is what the data loading and linop cells below are placed on.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Reload edited modules (e.g. the local invivo_data and linop modules) before
# each cell executes, so changes to them are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

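## Load data

Load the in-vivo dataset (`ksp`, `trj`, `dcf`, `phi`, `mps`) onto `device` and check that the arrays have mutually consistent shapes.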

In [3]:
ksp, trj, dcf, phi, mps = load_data(device, verbose=True)

# Cross-check the loaded arrays against each other now, rather than letting a
# mismatch surface later inside the (slow) operator construction.
# NOTE(review): axis roles are inferred from the printed shapes and are
# assumptions to confirm against invivo_data.load_data:
#   ksp (n_t, n_coil, n_read)   trj (n_t, n_dim, n_read)   dcf (n_t, n_read)
#   phi (n_subspace, n_t)       mps (n_coil, *image_shape)
assert trj.shape[0] == ksp.shape[0] and trj.shape[2] == ksp.shape[2], (
    f"trj {tuple(trj.shape)} does not match ksp {tuple(ksp.shape)}"
)
assert dcf.shape == (ksp.shape[0], ksp.shape[2]), (
    f"dcf {tuple(dcf.shape)} does not match ksp {tuple(ksp.shape)}"
)
assert phi.shape[1] == ksp.shape[0], (
    f"phi {tuple(phi.shape)} does not match ksp {tuple(ksp.shape)}"
)
assert mps.shape[0] == ksp.shape[1], (
    f"mps {tuple(mps.shape)} does not match ksp {tuple(ksp.shape)}"
)
# One trajectory coordinate per image dimension of the coil maps.
assert trj.shape[1] == mps.ndim - 1, (
    f"trj has {trj.shape[1]} coordinates but mps has {mps.ndim - 1} image dims"
)

ksp shape = torch.Size([600, 16, 1000]), dtype = torch.complex64
trj shape = torch.Size([600, 2, 1000]), dtype = torch.float64
dcf shape = torch.Size([600, 1000]), dtype = torch.float32
phi shape = torch.Size([4, 600]), dtype = torch.complex64
mps shape = torch.Size([16, 200, 200]), dtype = torch.complex64


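## Create linops

Build the forward operator `A`, its adjoint `AH`, and the normal operator `AHA` (with `toeplitz=True`). Precomputing the Toeplitz kernels is by far the slowest step here (about 160 s in the recorded run).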

In [4]:
# Build the subspace operator factory. The last argument is torch.sqrt(dcf).
# NOTE(review): presumably sqrt-weighting so the normal operator applies dcf
# once overall — confirm in linop.SubspaceLinopFactory.
linop_factory = SubspaceLinopFactory(trj, phi, mps, torch.sqrt(dcf))
# The return value is discarded, so this relies on .to() moving the factory in
# place (as torch.nn.Module.to does) — TODO confirm for SubspaceLinopFactory.
linop_factory.to(device)
# Forward operator; the other two return values are kept as ishape/oshape
# (presumably input/output shapes).
A, ishape, oshape = linop_factory.get_forward()
# Adjoint operator; its extra return values are discarded.
AH, _, _ = linop_factory.get_adjoint()
# Normal operator with toeplitz=True. With verbose=True it reports running
# compute_weights and compute_kernels; the kernel step dominates the runtime.
AHA, _, _ = linop_factory.get_normal(toeplitz=True, device=device, verbose=True)

> Running compute_weights...
>> Time: 0.4297477239742875 s
> Running compute_kernels...
>> Calculating kernel(0, 0)
>> Calculating kernel(1, 0)
>> Calculating kernel(2, 0)
>> Calculating kernel(3, 0)
>> Calculating kernel(0, 1)
>> Calculating kernel(1, 1)
>> Calculating kernel(2, 1)
>> Calculating kernel(3, 1)
>> Calculating kernel(0, 2)
>> Calculating kernel(1, 2)
>> Calculating kernel(2, 2)
>> Calculating kernel(3, 2)
>> Calculating kernel(0, 3)
>> Calculating kernel(1, 3)
>> Calculating kernel(2, 3)
>> Calculating kernel(3, 3)
>> Time: 161.0998094920069 s


In [5]:
# Check memory usage of the tensors reachable from linop_factory.
reporter = MemReporter(linop_factory)
with warnings.catch_warnings():
    # pytorch_memlab calls tensor.storage() internally, which torch >= 2.0 flags
    # with a "TypedStorage is deprecated" UserWarning. Silence only that warning,
    # only for this call, so the report is not followed by deprecation noise.
    warnings.filterwarnings(
        "ignore", message="TypedStorage is deprecated", category=UserWarning
    )
    reporter.report()

Element type                                            Size  Used MEM
-------------------------------------------------------------------------------
Storage on cuda:0
Tensor0                                   (16, 1000, 1, 600)    73.24M
Tensor1                                       (1000, 1, 600)     2.29M
Tensor2                                             (4, 600)    19.00K
Tensor3                                       (16, 200, 200)     4.88M
Tensor4                                      (600, 16, 1000)     0.00B
Tensor5                                       (600, 2, 1000)     9.16M
Tensor6                                          (600, 1000)     0.00B
trj                                           (600, 2, 1000)     0.00B
Tensor7                                              (6145,)    48.50K
Tensor8                                              (6145,)    48.50K
Tensor9                                                 (2,)   512.00B
Tensor10                                          

  fact_numel = tensor.storage().size()
  data_ptr = tensor.storage().data_ptr()
