<a href="https://colab.research.google.com/github/handsomecoderyang/deep-learning-for-image-processing/blob/master/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install tensorboardX
# pip install utils
# from utils.vis_tools import Visualizer
 

SyntaxError: ignored

In [None]:
import torch
import pathlib
from torch.utils.data import Dataset
import numpy as np

class SliceData(Dataset):
  """Dataset exposing every slice of every ``.npy`` k-space file in a directory.

  Each file under ``root`` is loaded with ``np.load``; its first axis is treated
  as the slice index. (Per-slice layout is presumably ``(H, W, 2)`` with real
  and imaginary parts in the last axis — TODO confirm against the data.)

  Args:
    root: Directory containing the k-space files.
    transform: Callable invoked as ``transform(kspace, fname=..., slice=...)``
      for every item; its return value is what ``__getitem__`` returns.
  """

  def __init__(self, root, transform):
    self.transform = transform
    # List of (file path, slice index) pairs, one per dataset item.
    self.examples = []
    # Sort so the index -> (file, slice) mapping is reproducible across runs;
    # ``iterdir`` yields entries in arbitrary order.
    for fname in sorted(pathlib.Path(root).iterdir()):
      # Memory-map the file: only its shape is needed here, so avoid reading
      # the whole volume into RAM.
      num_slices = np.load(fname, mmap_mode='r').shape[0]
      self.examples.extend((fname, slice_num) for slice_num in range(num_slices))

  def __len__(self):
    return len(self.examples)

  def __getitem__(self, index):
    fname, slice_num = self.examples[index]
    # Read only the requested slice; ``np.array`` copies it out of the memory
    # map into an ordinary writable array that ``torch.from_numpy`` accepts.
    data_kspace = np.array(np.load(fname, mmap_mode='r')[slice_num])
    data_kspace = torch.from_numpy(data_kspace)
    # Hand the integer slice index (not a string) to the transform.
    return self.transform(data_kspace, fname=fname, slice=slice_num)



In [None]:
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F

from torch.nn.modules.dropout import Dropout
from torch.nn.modules.activation import LeakyReLU
class ConvBlock(nn.Module):
  """Two 3x3 convolutions, each followed by InstanceNorm, LeakyReLU(0.2) and Dropout2d.

  Both convolutions use ``stride=1, padding=1``, so the spatial size is kept;
  only the channel count changes from ``inchans`` to ``outchans``.

  Args:
    inchans: Number of input channels.
    outchans: Number of output channels.
    drop_prob: Probability passed to each ``nn.Dropout2d``.
  """

  def __init__(self, inchans: int, outchans: int, drop_prob: float):
    super().__init__()

    self.inchans = inchans
    self.outchans = outchans
    self.drop_prob = drop_prob

    self.layers = nn.Sequential(
        nn.Conv2d(self.inchans, self.outchans, kernel_size=3, stride=1, padding=1, bias=False),
        nn.InstanceNorm2d(self.outchans),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        nn.Dropout2d(drop_prob),

        nn.Conv2d(self.outchans, self.outchans, kernel_size=3, stride=1, padding=1, bias=False),
        nn.InstanceNorm2d(self.outchans),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        nn.Dropout2d(drop_prob),
    )

  def forward(self, image: torch.Tensor) -> torch.Tensor:
    """Map an ``(N, inchans, H, W)`` tensor to ``(N, outchans, H, W)``."""
    return self.layers(image)

class TransposeConvBlock(nn.Module):
  """2x upsampling: ConvTranspose2d(kernel 2, stride 2) + InstanceNorm + LeakyReLU(0.2).

  Args:
    inchans: Number of input channels.
    outchans: Number of output channels.
  """

  def __init__(self, inchans: int, outchans: int):
    super().__init__()

    self.inchans = inchans
    self.outchans = outchans

    self.layers = nn.Sequential(
        nn.ConvTranspose2d(self.inchans, self.outchans, kernel_size=2, stride=2, bias=False),
        nn.InstanceNorm2d(self.outchans),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
    )

  def forward(self, image: torch.Tensor) -> torch.Tensor:
    """Map ``(N, inchans, H, W)`` to ``(N, outchans, 2*H, 2*W)``."""
    return self.layers(image)

class Unet(nn.Module):
  """U-Net following the original paper, built from ConvBlock / TransposeConvBlock.

  Encoder: ``num_pool_layers`` ConvBlocks, each followed by 2x2 average pooling,
  with channels ``chans, 2*chans, 4*chans, ...``. A bottleneck ConvBlock doubles
  the channels once more. Decoder: at each level a TransposeConvBlock upsamples
  by 2 and halves the channels, the result is concatenated with the matching
  encoder output (skip connection) and fed through a ConvBlock. The last level
  ends with a 1x1 convolution to ``out_chans``.

  Args:
    in_chans: Channels of the input image.
    out_chans: Channels of the output image.
    chans: Channels produced by the first encoder block.
    num_pool_layers: Number of down-/up-sampling levels.
    drop_prob: Dropout probability used inside every ConvBlock.
  """

  def __init__(self, in_chans: int, out_chans: int, chans: int = 32,
               num_pool_layers: int = 4, drop_prob: float = 0.0):
    super().__init__()

    self.inchans = in_chans
    self.out_chans = out_chans
    self.chans = chans
    self.num_pool_layers = num_pool_layers
    self.drop_prob = drop_prob
    self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)])
    ch = chans
    for _ in range(num_pool_layers - 1):
      self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob))
      ch *= 2
    # Bottleneck block applied at the lowest resolution.
    self.conv = ConvBlock(ch, ch * 2, drop_prob)

    self.up_conv = nn.ModuleList()
    self.up_transpose_conv = nn.ModuleList()
    for _ in range(num_pool_layers - 1):
      self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch))
      self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob))
      ch //= 2
    # The last decoder level also projects to the requested output channels.
    self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch))
    self.up_conv.append(
        nn.Sequential(
            ConvBlock(ch * 2, ch, drop_prob),
            nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1)
        )
    )

  def forward(self, image: torch.Tensor) -> torch.Tensor:
    """Run the U-Net on an ``(N, in_chans, H, W)`` tensor; returns ``(N, out_chans, H, W)``."""
    stack = []  # encoder outputs, consumed in reverse by the decoder
    output = image

    # Encoder: keep each level's features for the skip connection, then halve H and W.
    # With the defaults (4 levels) a 256x256 input reaches the bottleneck at 16x16.
    for layer in self.down_sample_layers:
      output = layer(output)
      stack.append(output)
      output = F.avg_pool2d(output, kernel_size=2, stride=2)

    output = self.conv(output)

    # Decoder: upsample, align with the matching encoder output, concatenate, convolve.
    for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv):
      downsample_layer = stack.pop()
      output = transpose_conv(output)

      # Pooling floors odd sizes, so the upsampled map can be one pixel short.
      # F.pad's order for the last two dims is [left, right, top, bottom];
      # pad on the right / bottom only.
      padding = [0, 0, 0, 0]
      if output.shape[-1] != downsample_layer.shape[-1]:
        padding[1] = 1
      if output.shape[-2] != downsample_layer.shape[-2]:
        padding[3] = 1
      if any(padding):
        output = F.pad(output, padding, "reflect")

      output = torch.cat((downsample_layer, output), dim=1)
      output = conv(output)

    return output


In [None]:
import logging
import pathlib
import shutil
import time
import numpy as numpy
import torch
import torchvision
from tensorboardX import SummaryWriter
from torch.nn import functional as F
from torch.utils.data import DataLoader
import scipy.io as sio
import argparse
import os
import sys

# sys.path.remove("/content/drive/MyDrive/MRI")
# print(sys.path)
# Expose only GPU 0 to this process. The variable must be spelled exactly
# CUDA_VISIBLE_DEVICES for the CUDA runtime to honour it, and it only takes
# effect if set before CUDA is first initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

<class 'pathlib.Path'>


In [None]:
# Training progress is reported through ``logging.info``; the default WARNING
# level would hide it, so enable INFO output.
logging.basicConfig(level=logging.INFO)
# basicConfig is a no-op when the root logger already has handlers (as some
# notebook environments set up), so also set the root level explicitly.
logging.getLogger().setLevel(logging.INFO)
logger = logging.getLogger(__name__)



In [None]:
class DataTransform:
  """Turn one k-space slice into an ``(undersampled image, fully-sampled image)`` pair.

  Args:
    mask: NumPy sampling mask multiplied elementwise with the complex k-space
      (presumably shape ``(H, W)`` matching the slice — TODO confirm).
  """

  def __init__(self, mask):
    self.mask = torch.from_numpy(mask)

  def __call__(self, kspace, fname, slice):
    """Mask the k-space and reconstruct magnitude images.

    Args:
      kspace: Tensor indexed as ``kspace[:, :, 0]`` (real) and
        ``kspace[:, :, 1]`` (imaginary); must be float32/float64 because it is
        combined with ``torch.complex``.
      fname: Source file; accepted for the SliceData interface, unused here.
      slice: Slice index; accepted for the SliceData interface, unused here.

    Returns:
      ``(under_img, target_img)``: magnitudes of the inverse 2-D FFT of the
      masked and of the full k-space. The undersampled (network input) image
      comes first, which is the order the training, evaluation and
      visualisation loops unpack each batch in.
    """
    target_kspace = torch.complex(kspace[:, :, 0], kspace[:, :, 1])
    under_kspace = torch.mul(target_kspace, self.mask)

    target_img = torch.abs(torch.fft.ifft2(target_kspace))
    under_img = torch.abs(torch.fft.ifft2(under_kspace))

    return under_img, target_img

In [None]:
def create_datasets(args):
  """Build the validation and training datasets that share one sampling mask.

  The mask is the ``'Umask'`` entry of
  ``./mask/<data>/<mask>/<mask>_256_256_<rate>.mat`` (relative to the working
  directory). Training slices come from ``<data_path>/Train_part1`` and
  validation slices from ``<data_path>/Val``.

  Returns:
    ``(dev_data, train_data)`` — note that the validation set comes first.
  """
  mask_file = "./mask/%s/%s/%s_256_256_%d.mat" % (args.data, args.mask, args.mask, args.rate)
  mask = sio.loadmat(mask_file)['Umask']

  def _split(subdir):
    # Each split gets its own DataTransform built from the same mask array.
    return SliceData(root=args.data_path / subdir, transform=DataTransform(mask))

  train_data = _split('Train_part1')
  dev_data = _split('Val')
  return dev_data, train_data

In [None]:
def create_data_loaders(args):
  """Create the train, validation and display ``DataLoader``s.

  The display set holds up to 16 validation samples spread evenly over the
  validation set. It is served as a single batch of 16 so that one grid of
  examples can be rendered per epoch.

  Args:
    args: Parsed options; reads ``batch_size`` plus whatever
      ``create_datasets`` needs. Optional ``num_workers`` (default 4) sets the
      loader worker count.

  Returns:
    ``(train_loader, dev_loader, display_loader)``.
  """
  dev_data, train_data = create_datasets(args)

  num_display = 16
  # ``max(1, ...)`` keeps the step valid for validation sets with < 16 items
  # (a zero step would make ``range`` raise).
  step = max(1, len(dev_data) // num_display)
  display_data = [dev_data[i] for i in range(0, len(dev_data), step)[:num_display]]

  num_workers = getattr(args, 'num_workers', 4)

  train_loader = DataLoader(
      dataset=train_data,
      batch_size=args.batch_size,
      shuffle=True,
      num_workers=num_workers,
      pin_memory=True,
  )
  dev_loader = DataLoader(
      dataset=dev_data,
      batch_size=args.batch_size,
      num_workers=num_workers,
      pin_memory=True,
  )
  display_loader = DataLoader(
      dataset=display_data,
      batch_size=num_display,
      num_workers=num_workers,
      pin_memory=True,
  )

  return train_loader, dev_loader, display_loader



In [None]:
def train_epoch(args, epoch, model, data_loader, optimizer, writer):
  """Train ``model`` for one pass over ``data_loader`` with an L1 loss.

  Each batch is unpacked as ``(under_img, target_img)``; a channel axis is
  inserted into the input at dim 1 and removed from the output, so the images
  are presumably ``(N, H, W)`` — TODO confirm. Both tensors are cast to float
  and moved to the device holding the model's parameters.

  Every step's loss is logged to ``writer`` as ``'TrainLoss'``. Every
  ``args.report_interval`` steps a progress line is emitted via ``logging.info``.

  Returns:
    ``(avg_loss, elapsed_seconds)``, where ``avg_loss`` is an exponential
    moving average (factor 0.99) seeded with the first batch's loss.
  """
  model.train()
  device = next(model.parameters()).device
  avg_loss = 0.0
  start_epoch = start_iter = time.perf_counter()  # start time of this epoch
  global_step = epoch * len(data_loader)
  for step, (under_img, target_img) in enumerate(data_loader):
    under_img = under_img.float().unsqueeze(1).to(device)
    target_img = target_img.float().to(device)

    output = model(under_img).squeeze(1)
    loss = F.l1_loss(output, target_img)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    avg_loss = 0.99 * avg_loss + 0.01 * loss_value if step > 0 else loss_value
    writer.add_scalar('TrainLoss', loss_value, global_step + step)

    if step % args.report_interval == 0:
      # One message string: extra positional args to logging.info are treated
      # as %-format arguments, not as more text.
      logging.info(
          f'Epoch = [{epoch:3d}/{args.epochs:3d}] '
          f'Iter = [{step:4d}/{len(data_loader):4d}] '
          f'Loss = {loss_value:.4g} Avg_loss = {avg_loss:.4g} '
          f'Time = {time.perf_counter() - start_iter:.4f}s'
      )
      start_iter = time.perf_counter()
  return avg_loss, time.perf_counter() - start_epoch

In [None]:
def evaluate(args, epoch: int, model, data_loader, writer, vis):
  """Compute the mean per-batch MSE of ``model`` over ``data_loader``.

  Batches are unpacked as ``(under_img, target_img)`` and prepared exactly as
  in training: cast to float, with a channel axis inserted into the input,
  and moved to the model's device. The mean loss is logged to ``writer`` as
  ``'Dev_loss'`` at step ``epoch``. When ``vis`` is not None, the mean loss is
  also sent to ``vis.plot``.

  Returns:
    ``(mean_loss, elapsed_seconds)``.
  """
  model.eval()
  device = next(model.parameters()).device
  losses = []
  start = time.perf_counter()
  with torch.no_grad():
    for under_img, target_img in data_loader:
      under_img = under_img.float().unsqueeze(1).to(device)
      target_img = target_img.float().to(device)

      output = model(under_img).squeeze(1)
      losses.append(F.mse_loss(output, target_img).item())
    mean_loss = np.mean(losses)
    writer.add_scalar('Dev_loss', mean_loss, epoch)
    if vis is not None:
      vis.plot("val Loss", mean_loss)

  return mean_loss, time.perf_counter() - start


In [None]:
def visualize(args, epoch, model, data_loader, writer, vis):
  """Log target / reconstruction / error images for the first batch of ``data_loader``.

  Image grids are always written to ``writer`` under the tags ``'Target'``,
  ``'Reconstruction'`` and ``'Error'`` at step ``epoch``. When ``vis`` is not
  None, each sample is additionally sent to ``vis.img``. Only the first batch
  is used.
  """
  def save_image(image, tag):
    # Min-max normalise a copy so the caller's tensors are not modified
    # (the error map is computed from the un-normalised tensors).
    image = image - image.min()
    max_value = image.max()
    if max_value > 0:  # avoid 0/0 -> NaN for a constant image
      image = image / max_value
    grid = torchvision.utils.make_grid(image, nrow=4, pad_value=1)
    writer.add_image(tag, grid, epoch)

  model.eval()
  device = next(model.parameters()).device
  with torch.no_grad():
    for under_img, target_img in data_loader:
      under_img = under_img.float().unsqueeze(1).to(device)
      target_img = target_img.float().to(device)
      output = model(under_img)

      if vis is not None:
        for i in range(len(output)):
          vis.img("undersampled image - %d" % (i), under_img.squeeze(1)[i])
          vis.img("full image - %d" % (i), target_img[i])
          vis.img("recons image - %d" % (i), output[i])

      target = target_img.unsqueeze(1)
      save_image(target, 'Target')
      save_image(output, 'Reconstruction')
      save_image(torch.abs(target - output), 'Error')
      break

In [None]:
def save_model(args, exp_dir, epoch, model, optimizer, best_dev_loss, is_new_best):
  """Write a training checkpoint to ``exp_dir / 'model.pt'``.

  The saved dict contains ``epoch``, ``args``, the model and optimizer state
  dicts, ``best_dev_loss`` and ``exp_dir``. When ``is_new_best`` is true, the
  file is also copied to ``exp_dir / 'best_model.pt'``.
  """
  checkpoint_path = exp_dir / 'model.pt'
  state = dict(
      epoch=epoch,
      args=args,
      model=model.state_dict(),
      optimizer=optimizer.state_dict(),
      best_dev_loss=best_dev_loss,
      exp_dir=exp_dir,
  )
  torch.save(state, f=checkpoint_path)
  if not is_new_best:
    return
  shutil.copyfile(checkpoint_path, exp_dir / 'best_model.pt')


In [None]:
def build_model(args):
  """Create the U-Net used for training and move it to the target device.

  The architecture is fixed: 1 input and 1 output channel, 64 base channels,
  4 pooling levels, no dropout. The device is ``args.device`` (the
  ``--device`` option) when set; otherwise the module-level ``device`` is used.
  """
  # Short-circuit: the global ``device`` is only consulted when args has none.
  target_device = getattr(args, 'device', None) or device
  return Unet(1, 1, 64, 4, 0).to(target_device)


In [None]:

def load_model(checkpoint_files):
  """Rebuild the model and optimizer from a checkpoint written by ``save_model``.

  Args:
    checkpoint_files: Path of the checkpoint file.

  Returns:
    ``(checkpoint, model, optimizer)``. ``checkpoint`` is the raw dict. The
    model and optimizer are created via ``build_model`` / ``build_optim``
    from the stored ``args`` and restored from the stored state dicts.
  """
  import inspect

  # Load onto the CPU first so a GPU-saved checkpoint also opens on a CPU-only
  # host; load_state_dict then copies the tensors onto the model's device.
  load_kwargs = {'map_location': 'cpu'}
  # The checkpoint pickles ``args`` (an argparse Namespace holding pathlib
  # paths), which PyTorch >= 2.6 rejects under its default weights_only=True.
  # Older PyTorch releases have no such parameter at all.
  # NOTE: full unpickling can execute arbitrary code -- only load trusted files.
  if 'weights_only' in inspect.signature(torch.load).parameters:
    load_kwargs['weights_only'] = False
  checkpoint = torch.load(checkpoint_files, **load_kwargs)
  args = checkpoint['args']
  model = build_model(args)
  model.load_state_dict(checkpoint['model'])
  optimizer = build_optim(args, model.parameters())
  optimizer.load_state_dict(checkpoint['optimizer'])
  return checkpoint, model, optimizer

In [None]:
def build_optim(args, params):
  """Return an Adam optimizer over ``params`` using learning rate ``args.lr``."""
  learning_rate = args.lr
  return torch.optim.Adam(params, lr=learning_rate)
  

In [None]:
def main(args, vis):
  """Train for ``args.epochs`` epochs, evaluating, visualising and checkpointing each epoch.

  With ``args.resume`` set, the model, optimizer, ``args`` and best validation
  loss are restored from ``args.checkpoint``. Training then continues with the
  epoch after the stored one. TensorBoard summaries go to
  ``args.exp_dir / 'summary'``.
  """
  args.exp_dir.mkdir(parents=True, exist_ok=True)
  writer = SummaryWriter(log_dir=args.exp_dir / 'summary')
  try:
    if args.resume:
      checkpoint, model, optimizer = load_model(args.checkpoint)
      args = checkpoint['args']
      best_dev_loss = checkpoint['best_dev_loss']
      # The stored epoch was already completed; continue with the next one.
      start_epoch = checkpoint['epoch'] + 1
      del checkpoint
    else:
      model = build_model(args)
      optimizer = build_optim(args, model.parameters())
      best_dev_loss = 1e9
      start_epoch = 0
    logging.info(args)
    logging.info(model)

    train_loader, dev_loader, display_loader = create_data_loaders(args)

    for epoch in range(start_epoch, args.epochs):
      train_loss, train_time = train_epoch(args, epoch, model, train_loader, optimizer, writer)

      dev_loss, dev_time = evaluate(args, epoch, model, dev_loader, writer, vis)
      visualize(args, epoch, model, display_loader, writer, vis)

      is_new_best = dev_loss < best_dev_loss
      # Track the running best so best_model.pt is only replaced on improvement.
      best_dev_loss = min(best_dev_loss, dev_loss)
      save_model(args, args.exp_dir, epoch, model, optimizer, best_dev_loss, is_new_best)
      logging.info(
          f'epoch = [{epoch:4d}/{args.epochs:4d}] TrainLoss = {train_loss:.4g} '
          f'DevLoss = {dev_loss:.4g} TrainTime = {train_time:.4f}s DevTime = {dev_time:.4f}s',
      )
  finally:
    # Close once, after all epochs (or on error), so every epoch's summaries are flushed.
    writer.close()

In [None]:

import argparse


def create_arg_parser():
  """Build the command-line parser for the training script.

  Returns:
    An ``argparse.ArgumentParser`` with training, data, checkpoint and
    device options.
  """
  def str2bool(value):
    # ``type=bool`` would turn any non-empty string (including "False") into True.
    if isinstance(value, bool):
      return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
      return True
    if lowered in ('0', 'false', 'f', 'no', 'n', ''):
      return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

  parser = argparse.ArgumentParser()
  parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs')
  parser.add_argument('--batch_size', default=1, type=int, help='the batch size')
  parser.add_argument('--lr', default=0.0005, type=float, help='the learning rate')
  parser.add_argument('--rate', default=20, type=int, choices=[5, 10, 20, 25], help='the undersampling rate')
  parser.add_argument('--mask', default='radial', type=str, choices=['catesian', 'radial', 'random'], help='the type of mask')
  parser.add_argument('--data', default='brain', type=str, choices=['brain', 'kenn'], help='which dataset(brain or knee')
  parser.add_argument('--report_interval', default=100, type=int, help='period of loss reporting')
  parser.add_argument('--exp_dir', default='cheakpoints', type=pathlib.Path, help='path where model and results should be saved')
  parser.add_argument('--resume', action='store_true', help='if set, resume the training from a previous model checkpoint. "--checkpoint" should be set with this')
  parser.add_argument('--checkpoint', default="/content/drive/MyDrive/MRI/", type=str, help='path to an existing checkpoint. used along with "--resume"')
  parser.add_argument('--data_path', default="/content/drive/MyDrive/MRI/", type=pathlib.Path, required=False, help='path to the dataset')
  parser.add_argument('--device', type=str, default='cuda:0', help='which device to train on.set to "cuda:n", n represent GPU number')
  parser.add_argument('--use_visdom', type=str2bool, default=False, help='if true, watch loss and reconstruction on port http://localhost:8097')
  return parser

In [None]:
root_mri = '/content/drive/MyDrive/MRI/'
# Relative paths used later (e.g. ./mask/... and the checkpoint dir) resolve from here.
os.chdir(root_mri)
vis = None  # no visdom Visualizer is constructed in this notebook
# parse_known_args tolerates extra entries in sys.argv, such as the
# ``-f <connection file>`` argument a Jupyter kernel is started with.
args = create_arg_parser().parse_known_args()[0]
# Module-level device used by the training helpers; follow the --device option.
device = args.device
print(args)
main(args, vis)

Namespace(batch_size=1, checkpoint='/content/drive/MyDrive/MRI/', data='brain', data_path=PosixPath('/content/drive/MyDrive/MRI'), device='cuda:0', epochs=50, exp_dir=PosixPath('cheakpoints'), lr=0.0005, mask='radial', rate=20, report_interval=100, resume=False, use_visdom=False)


  cpuset_checked))


torch.Size([1, 1, 256, 256]) torch.Size([1, 256, 256])
ConvBlock(
  (layers): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Dropout2d(p=0, inplace=False)
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (5): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (6): LeakyReLU(negative_slope=0.2, inplace=True)
    (7): Dropout2d(p=0, inplace=False)
  )
)
ConvBlock(
  (layers): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Dropout2d(p=0, inplace=False)
    (4): Conv2d(128, 128, kernel_size=(3, 3), str

RuntimeError: ignored

In [None]:
# !apt install psmisc
# !sudo fuser /dev/nvidia*
# !kill -9 73
