<a href="https://colab.research.google.com/github/matthewchung74/not_so_blackout/blob/main/SRGAN_blackout.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Not so Blackout

This notebook demonstrates another technique for blacking out PHI

* based on https://colab.research.google.com/drive/1eV9BCLPiBrGllj1vQek2LZkOPuMMZPXa?usp=sharing

* First, check which GPU is attached. Things will run best on a `V100`.

In [None]:
# Show which GPU (if any) is attached to this runtime.
!nvidia-smi

In [None]:
!pip install Faker -Uq

### Mount drive

In [None]:
try:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    COLAB = True
    print("Note: using Google CoLab")
    %tensorflow_version 2.x
except:
    print("Note: not using Google CoLab")
    COLAB = False

### Crapify images

Set the directories in Drive, and then add text to the images to crapify them.

In [None]:
# Google Drive folders used by the notebook; all of them live under one data root.
_DRIVE_DATA = "/content/drive/MyDrive/data"

# Source X-rays read by the image-generation loop.
my_path = f"{_DRIVE_DATA}/gan/images/chest_xray"
# Text-stamped copies: written by the generation loop, read as the network input.
chest_xray_text = f"{_DRIVE_DATA}/superres/images/chest_xray_text"
# Clean images used as training targets, paired with the stamped copies by file name.
chest_xray_resized = f"{_DRIVE_DATA}/gan/images/chest_xray_resized"

In [None]:
import faker
import os
from PIL import Image, ImageDraw, ImageFont, ImageOps
from faker import Faker
import random

# Shared Faker instance that produces the random names drawn onto the images.
fake = Faker()

# Candidate fonts for the stamped text. Keep only font files, because any other
# entry in the directory would make ImageFont.truetype fail when it is picked,
# and sort so the list order does not depend on the filesystem.
font_dir = "/usr/share/fonts/truetype/liberation"
fonts = sorted(
    name for name in os.listdir(font_dir)
    if name.lower().endswith((".ttf", ".otf", ".ttc"))
)

def _paste_random_name(img, position):
  """Draw one fake name at `position` on a blank mask and paste it onto `img` in place.

  Each call picks a random font from `fonts`, a random font size (6-14), a random
  grey level (150-255) and rotates the whole mask by either 0 or 90 degrees.
  """
  font = ImageFont.truetype(f"{font_dir}/{random.choice(fonts)}", random.randint(6,14))
  text_layer = Image.new('L', img.size)
  ImageDraw.Draw(text_layer).text(position, fake.name(), font=font, fill=random.randint(150,255))

  rotated_text_layer = text_layer.rotate(0.0+90*random.randint(0,1), expand=1)
  # The grey-level layer is also the paste mask, so only the text pixels are written.
  img.paste(ImageOps.colorize(rotated_text_layer, (0,0,0), (255, 255,255)), None, rotated_text_layer)


def adding_random_text(img):
  """Stamp two random fake names onto `img`: one near the top, one near the bottom.

  Args:
    img: PIL image. It is modified in place.

  Returns:
    The same `img` object, for call chaining.
  """
  width, height = img.size

  # top: x anywhere in the left half, y within the first few pixels.
  # Integer division keeps both randint bounds integral; a float bound is
  # rejected by newer Python versions and by odd widths.
  _paste_random_name(img, (random.randint(0, width // 2), random.randint(0,5)))

  # bottom: x near the left edge, y 20-60 pixels above the bottom edge.
  _paste_random_name(img, (random.randint(0,60), height - random.randint(20,60)))

  return img

# Set to True to (re)generate the text-stamped images in `chest_xray_text`.
GENERATE_TEXT_IMAGES = False

if GENERATE_TEXT_IMAGES:
  # Make sure the destination exists before the first save.
  os.makedirs(chest_xray_text, exist_ok=True)
  for file in os.listdir(my_path):
      f_img = my_path+"/"+file
      d_img = chest_xray_text+"/"+file
      # Close the source file as soon as the converted, resized copy exists.
      with Image.open(f_img) as src:
          img = src.convert('RGBA').resize((256,256))

      # Stamp the random names, then drop the alpha channel before saving.
      img = adding_random_text(img).convert('RGB')
      img.save(d_img)

Then view the crapified images.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
# Preview the first four text-stamped images in a 2x2 grid.
image_paths = os.listdir(chest_xray_text)[:4]
img_arr = [
    np.asarray(Image.open(f"{chest_xray_text}/{name}"))
    for name in image_paths
]

fig = plt.figure(figsize=(20., 20.))
grid = ImageGrid(
    fig,
    111,
    nrows_ncols=(2, 2),  # creates 2x2 grid of axes
    axes_pad=0.1,        # pad between axes
)

# One image per axis; zip stops at whichever runs out first.
for ax, im in zip(grid, img_arr):
    ax.imshow(im)


### Grab a Model
grabbed from https://colab.research.google.com/drive/1eV9BCLPiBrGllj1vQek2LZkOPuMMZPXa?usp=sharing

In [None]:
import torch
import math
from os import listdir
import numpy as np
from torch.autograd import Variable
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from os.path import join
from pathlib import Path

# Turn on autograd anomaly detection for the whole session.
# NOTE(review): this is a debugging aid and is expected to slow training down;
# consider disabling it for full training runs.
torch.autograd.set_detect_anomaly(True)

### Create Dataset objects

In [None]:
# Scale factor handed to Generator(...) further down. With 1, log2(1) = 0 upsample
# blocks are built, so the output keeps the input resolution.
UPSCALE_FACTOR = 1
# NOTE(review): CROP_SIZE is not referenced anywhere else in the visible notebook — possibly a leftover.
CROP_SIZE = 100

# Per-channel mean/std constants.
# NOTE(review): these look like the usual ImageNet normalization statistics, but nothing
# in the visible notebook applies them — confirm whether they are still needed.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Now, I will load in some code for the dataset and dataloaders.
# Link to this notebook will be in the description, so you can get it from there
def is_image_file(filename):
    """Return True when `filename` ends with a PNG/JPEG extension.

    The comparison ignores letter case, so mixed-case names such as
    "scan.Png" are accepted alongside ".png" and ".PNG".
    """
    return filename.lower().endswith(('.png', '.jpg', '.jpeg'))

class TrainDatasetFromFolder(Dataset):
    """Paired dataset of text-stamped inputs and clean target images.

    Every image file found in `chest_xray_resized` is a target; its input is the
    file with the same name in `chest_xray_text`. Items are returned as
    `(lr_image, hr_image)` tensors produced by `ToTensor()`.

    Args:
        chest_xray_text: folder holding the text-stamped images (network input).
        chest_xray_resized: folder holding the clean images (training target).
    """
    def __init__(self, chest_xray_text, chest_xray_resized):
        super(TrainDatasetFromFolder, self).__init__()
        self.chest_xray_text_folder = chest_xray_text
        self.chest_xray_resized_folder = chest_xray_resized

        # A single listing of the input folder is enough to check pairing.
        # Targets without a stamped counterpart are dropped here instead of
        # raising FileNotFoundError in the middle of an epoch.
        stamped_names = set(listdir(chest_xray_text))
        # Sorted so that index -> file is stable across runs and machines.
        candidates = sorted(x for x in listdir(chest_xray_resized) if is_image_file(x))
        self.chest_xray_resized = [join(chest_xray_resized, x) for x in candidates if x in stamped_names]

        skipped = len(candidates) - len(self.chest_xray_resized)
        if skipped:
            print(f"skipping {skipped} image(s) with no counterpart in {chest_xray_text}")

        print(f"chest_xray_text_folder: {self.chest_xray_text_folder}")
        print(f"chest_xray_resized_folder: {self.chest_xray_resized_folder}")

        self.hr_transform = ToTensor()
        self.lr_transform = ToTensor()

    @staticmethod
    def _load_rgb(path):
        """Open `path`, return an in-memory RGB copy, and close the file handle."""
        with Image.open(path) as im:
            # The networks' first convolutions take 3 channels, so grey-scale or
            # RGBA files are normalized to RGB here.
            return im.convert('RGB')

    def __getitem__(self, index):
        """Return `(lr_image, hr_image)`: the stamped input and its clean target."""
        target_path = self.chest_xray_resized[index]
        hr_image = self.hr_transform(self._load_rgb(target_path))
        lr_image = self.lr_transform(
            self._load_rgb(f"{self.chest_xray_text_folder}/{Path(target_path).name}")
        )
        return lr_image, hr_image

    def __len__(self):
        """Number of usable (input, target) pairs."""
        return len(self.chest_xray_resized)

In [None]:
# Mini-batch size used by the training DataLoader.
BATCH_SIZE = 4
# Pairs each clean image in chest_xray_resized with the same-named image in chest_xray_text.
train_set = TrainDatasetFromFolder(chest_xray_text, chest_xray_resized)
# Shuffled loading with a single worker process.
trainloader = DataLoader(train_set, batch_size=BATCH_SIZE, num_workers=1, shuffle=True)

In [None]:
from torch import nn, optim

class ResidualBlock(nn.Module):
  """Conv -> BatchNorm -> PReLU -> Conv -> BatchNorm branch added back onto its input.

  Both convolutions are 3x3 with padding 1 and keep the channel count, so the
  output has the same shape as the input.
  """
  def __init__(self, channels):
    super().__init__()
    conv_kwargs = dict(kernel_size=3, padding=1)
    self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
    self.bn1 = nn.BatchNorm2d(channels)
    self.prelu = nn.PReLU()
    self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
    self.bn2 = nn.BatchNorm2d(channels)

  def forward(self, x):
    """Return `x` plus the output of the residual branch applied to `x`."""
    branch = x
    for layer in (self.conv1, self.bn1, self.prelu, self.conv2, self.bn2):
      branch = layer(branch)
    return x + branch
  
class UpsampleBlock(nn.Module):
  """Sub-pixel upsampling stage.

  A 3x3 convolution widens the channels by `up_scale ** 2`, PixelShuffle
  rearranges those extra channels into an `up_scale` times larger spatial grid,
  and a PReLU follows.
  """
  def __init__(self, in_channels, up_scale):
    super().__init__()
    expanded_channels = in_channels * up_scale ** 2
    self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1)
    self.pixel_shuffle = nn.PixelShuffle(up_scale)
    self.prelu = nn.PReLU()

  def forward(self, x):
    """Apply conv, pixel shuffle and PReLU in that order."""
    return self.prelu(self.pixel_shuffle(self.conv(x)))

class Generator(nn.Module):
  """SRGAN-style generator: 9x9 conv head, five residual blocks, conv+BN, then the output stage.

  Args:
    scale_factor: upscaling factor. `int(log2(scale_factor))` UpsampleBlock(64, 2)
      stages are built, so a factor of 1 builds none and the output keeps the
      input resolution. Values that are not a power of two are rounded down to
      the nearest one.
  """
  def __init__(self, scale_factor):
    super(Generator, self).__init__()
    # math.log2 is the dedicated base-2 logarithm and avoids the floating-point
    # division that math.log(x, 2) performs.
    upsample_block_num = int(math.log2(scale_factor))

    self.block1 = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=9, padding=4),
        nn.PReLU()
    )

    self.block2 = ResidualBlock(64)
    self.block3 = ResidualBlock(64)
    self.block4 = ResidualBlock(64)
    self.block5 = ResidualBlock(64)
    self.block6 = ResidualBlock(64)
    self.block7 = nn.Sequential(
        nn.Conv2d(64, 64, kernel_size=3, padding=1),
        nn.BatchNorm2d(64)
    )
    block8 = [UpsampleBlock(64, 2) for _ in range(upsample_block_num)]
    block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
    self.block8 = nn.Sequential(*block8)

  def forward(self, x):
    """Map a 3-channel image batch to a 3-channel output with values in [0, 1]."""
    block1 = self.block1(x)

    # block2..block7 run strictly in sequence on the head's features.
    features = block1
    for block in (self.block2, self.block3, self.block4,
                  self.block5, self.block6, self.block7):
      features = block(features)

    # Long skip connection: the head's features are added back before the output stage.
    block8 = self.block8(block1 + features)
    # tanh gives [-1, 1]; shift and scale to [0, 1] to match ToTensor() images.
    return (torch.tanh(block8) + 1) / 2

class Discriminator(nn.Module):
  """Convolutional critic that maps a 3-channel image batch to one score in (0, 1) per image.

  Layout: a 3->64 conv head, six conv + BatchNorm + LeakyReLU(0.2) stages (three
  of them with stride 2), global average pooling, and two 1x1 convolutions that
  reduce to a single value.
  """
  def __init__(self):
    super().__init__()
    # (in_channels, out_channels, stride) of every conv + BatchNorm + LeakyReLU stage.
    stage_specs = [
        (64, 64, 2),
        (64, 128, 1),
        (128, 256, 1),
        (256, 256, 2),
        (256, 512, 1),
        (512, 512, 2),
    ]

    layers = [nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.LeakyReLU(0.2)]
    for c_in, c_out, stride in stage_specs:
      layers.extend([
          nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
          nn.BatchNorm2d(c_out),
          nn.LeakyReLU(0.2),
      ])
    layers.extend([
        nn.AdaptiveAvgPool2d(1),
        nn.Conv2d(512, 1024, kernel_size=1),
        nn.LeakyReLU(0.2),
        nn.Conv2d(1024, 1, kernel_size=1),
    ])
    # Same layer order as before, so the `net.<index>` keys of saved checkpoints still match.
    self.net = nn.Sequential(*layers)

  def forward(self, x):
    """Return a 1-D tensor with one sigmoid score per image in the batch."""
    logits = self.net(x)
    return torch.sigmoid(logits.view(x.size()[0]))

from torchvision.models.vgg import vgg16

# Now we got to make the Generator Loss
class TVLoss(nn.Module):
  """Total-variation penalty on a 4-D batch.

  Squared differences between vertically and horizontally adjacent pixels are
  summed, each sum is divided by the per-sample element count of its difference
  tensor, and the total is scaled by `2 * tv_loss_weight / batch_size`.
  """
  def __init__(self, tv_loss_weight=1):
    super().__init__()
    self.tv_loss_weight = tv_loss_weight

  def forward(self, x):
    """Return the weighted total-variation value of `x` as a scalar tensor."""
    n_batch = x.size(0)

    # Neighbouring rows (height axis) and neighbouring columns (width axis).
    lower_rows, upper_rows = x[:, :, 1:, :], x[:, :, :-1, :]
    right_cols, left_cols = x[:, :, :, 1:], x[:, :, :, :-1]

    h_tv = (lower_rows - upper_rows).pow(2).sum()
    w_tv = (right_cols - left_cols).pow(2).sum()

    tv = h_tv / self.tensor_size(lower_rows) + w_tv / self.tensor_size(right_cols)
    return self.tv_loss_weight * 2 * tv / n_batch

  @staticmethod
  def tensor_size(t):
    """Number of elements per sample: the product of dimensions 1, 2 and 3 of `t`."""
    return t.size(1) * t.size(2) * t.size(3)

class GeneratorLoss(nn.Module):
  """Composite generator loss: pixel MSE + adversarial + VGG perceptual + total variation.

  The constructor loads a pretrained torchvision VGG16 and keeps its first 31
  feature layers, frozen and in eval mode, as the perceptual feature extractor.
  """
  def __init__(self):
    super(GeneratorLoss, self).__init__()
    vgg = vgg16(pretrained=True)
    loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
    # The feature extractor is fixed; only the generator is trained through it.
    for param in loss_network.parameters():
      param.requires_grad = False
    self.loss_network = loss_network
    self.mse_loss = nn.MSELoss()
    self.tv_loss = TVLoss()

  def forward(self, out_labels, out_images, target_images):
    """Combine the four loss terms into one scalar.

    Args:
      out_labels: discriminator score(s) for the generated images.
      out_images: generator output batch.
      target_images: clean ground-truth batch with the same shape as `out_images`.

    Returns:
      image_loss + 0.001 * adversarial + 0.006 * perceptual + 2e-8 * total variation.
    """
    # Pushes the discriminator's score on generated images towards 1.
    adversarial_loss = torch.mean(1 - out_labels)
    # Perceptual term: compare VGG feature maps, not raw pixels. Previously this
    # repeated the pixel MSE below and left `loss_network` unused.
    perception_loss = self.mse_loss(self.loss_network(out_images),
                                    self.loss_network(target_images))
    image_loss = self.mse_loss(out_images, target_images)
    tv_loss = self.tv_loss(out_images)
    return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss

    

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device  # bare expression so the cell displays the selected device

In [None]:
# Build the generator, the discriminator and the composite generator loss.
# GeneratorLoss() loads a pretrained VGG16 in its constructor.
netG = Generator(UPSCALE_FACTOR)
netD = Discriminator()
generator_criterion = GeneratorLoss()

In [None]:
# Move the loss module and both networks to the selected device.
generator_criterion = generator_criterion.to(device)
netG = netG.to(device)
netD = netD.to(device)

# Training hyper-parameters: number of epochs and the learning rate shared by both optimizers.
N_EPOCHS = 150 
LR = 0.0002

# One Adam optimizer per network so each can be stepped independently.
optimizerG = optim.Adam(netG.parameters(), lr=LR)
optimizerD = optim.Adam(netD.parameters(), lr=LR)

# Metric history, one list per quantity; the keys match the running totals
# kept inside the training loop.
results = {
    "d_loss":[],
    "g_loss":[],
    "d_score": [],
    "g_score": []
}

In [None]:
## Now for training code
from tqdm import tqdm
import os

### Train

In [None]:
import torchvision.transforms as T

def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write the state dicts of `model` and `optimizer` to `filename` with torch.save.

    The saved object is a dict with the keys "state_dict" (model) and
    "optimizer" (optimizer).
    """
    print(f"Saving checkpoint {filename}")
    torch.save(
        dict(state_dict=model.state_dict(), optimizer=optimizer.state_dict()),
        filename,
    )

def test_rand_image(saved_pil_path):
  """Run the global `netG` on one freshly stamped random image and show/save the result.

  A random file from `chest_xray_resized` is stamped with random text, passed
  through `netG`, and the input and output are stacked along the height axis
  (input on top). The stacked image is displayed and written to `saved_pil_path`.

  Args:
    saved_pil_path: destination file for the stacked before/after image.
  """
  random_file = random.choice(os.listdir(chest_xray_resized))
  random_path = chest_xray_resized+"/"+random_file
  with Image.open(random_path) as src:
    # The generator's first convolution takes 3 channels, so grey-scale or RGBA
    # files are normalized to RGB before stamping.
    random_image_text = adding_random_text(src.convert('RGB'))
  random_image = ToTensor()(random_image_text).to(device)
  random_image = random_image[None, :]  # add the batch dimension

  # Inference only: skip building the autograd graph.
  with torch.no_grad():
    random_processed_image = netG(random_image)

  # Index the batch dimension explicitly; squeeze() would also drop any other size-1 dimension.
  cated = torch.cat((random_image[0], random_processed_image[0]), 1)
  pil_image = T.ToPILImage()(cated)
  display(pil_image)
  pil_image.save(saved_pil_path)


# Checkpoints and the per-epoch sample images all go to this Drive folder.
EXPERIMENT_DIR = "/content/drive/MyDrive/data/superres/experiments"
# Create it up front so a missing folder fails before training, not at the first save.
os.makedirs(EXPERIMENT_DIR, exist_ok=True)

for epoch in range(1, N_EPOCHS + 1):
  train_bar = tqdm(trainloader)
  # Sample-weighted running sums; divided by 'batch_sizes' to get epoch averages.
  running_results = {'batch_sizes':0, 'd_loss':0,
                     "g_loss":0, "d_score":0, "g_score":0}

  netG.train()
  netD.train()
  for data, target in train_bar:
    batch_size = data.size(0)
    running_results['batch_sizes'] += batch_size

    real_img = target.to(device)  # clean image
    z = data.to(device)           # text-stamped image

    ## Update Discriminator ##
    # Detach the generated batch: this step only updates netD, so there is no
    # need to backpropagate into netG or to retain the graph afterwards.
    fake_img = netG(z).detach()
    netD.zero_grad()
    real_out = netD(real_img).mean()
    fake_out = netD(fake_img).mean()
    d_loss = 1 - real_out + fake_out
    d_loss.backward()
    optimizerD.step()

    ## Now update Generator
    # Fresh forward pass through netG and the just-updated netD.
    fake_img = netG(z)
    fake_out = netD(fake_img).mean()
    netG.zero_grad()
    g_loss = generator_criterion(fake_out, fake_img, real_img)
    g_loss.backward()
    optimizerG.step()

    running_results['g_loss'] += g_loss.item() * batch_size
    running_results['d_loss'] += d_loss.item() * batch_size
    running_results['d_score'] += real_out.item() * batch_size
    # D(G(z)) is the discriminator's score on the generated batch. It was
    # previously accumulated from real_out, which duplicated D(x). The extra
    # forward pass that used to precede optimizerG.step() is dropped, because
    # the fake_out computed for g_loss already holds this value.
    running_results['g_score'] += fake_out.item() * batch_size

    ## Updating the progress bar
    train_bar.set_description(desc="[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f" % (
        epoch, N_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],
        running_results['g_loss'] / running_results['batch_sizes'],
        running_results['d_score'] / running_results['batch_sizes'],
        running_results['g_score'] / running_results['batch_sizes']
    ))

  # Record the epoch averages in the metric history (skipped for an empty loader).
  seen = running_results['batch_sizes']
  if seen:
    for key in ("d_loss", "g_loss", "d_score", "g_score"):
      results[key].append(running_results[key] / seen)

  save_checkpoint(netG, optimizerG, f"{EXPERIMENT_DIR}/g_{epoch}.pth")
  save_checkpoint(netD, optimizerD, f"{EXPERIMENT_DIR}/d_{epoch}.pth")
  # Eval mode for the sample image; netG.train() is restored at the top of the next epoch.
  netG.eval()
  test_rand_image(f"{EXPERIMENT_DIR}/test_{epoch}.jpg")

### Check your work

In [None]:
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore `model` and `optimizer` from a file written by `save_checkpoint`.

    Args:
        checkpoint_file: path to a checkpoint holding the keys "state_dict" and "optimizer".
        model: module whose parameters are overwritten in place.
        optimizer: optimizer whose state is overwritten in place.
        lr: learning rate to apply to every parameter group after loading.
            Without this the optimizer keeps the rate stored in the checkpoint.
            Pass None to keep the stored rate.
    """
    print("=> Loading checkpoint")
    # Load onto the CPU first so GPU-saved checkpoints also open on CPU-only
    # machines; load_state_dict then copies onto the model's own device.
    checkpoint = torch.load(checkpoint_file, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # Apply the requested learning rate, which was previously accepted but ignored.
    if lr is not None:
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

# Reload the generator saved after `epoch` and compare stamped inputs with its outputs.
epoch = 1
inference_netG = Generator(1)
load_checkpoint(f"/content/drive/MyDrive/data/superres/experiments/g_{epoch}.pth", inference_netG, optimizerG, lr=0.0002)
# Eval mode so BatchNorm uses its stored running statistics on these single images.
inference_netG.eval()

# NOTE(review): the file name is drawn from `my_path` but opened from
# `chest_xray_text`, so this image is already stamped and gets stamped a second
# time below — confirm whether `my_path` was the intended source folder.
f = random.choice(os.listdir(my_path))
d_img = chest_xray_text+"/"+f
with Image.open(d_img) as src:
    img = src.convert('RGBA').resize((256,256))


def _run_generator(image_tensor):
    """Run one (C, H, W) tensor through `inference_netG` and return the (C, H, W) result."""
    with torch.no_grad():
        # The network needs a batch dimension; add it for the call and remove it afterwards.
        return inference_netG(image_tensor.unsqueeze(0)).squeeze(0)


def _for_imshow(image_tensor):
    """Convert a (C, H, W) tensor into the (H, W, C) array layout that imshow expects."""
    return image_tensor.permute(1, 2, 0).cpu().numpy()


# Pair 1: the first training sample.
before_img1 = train_set[0][0]
after_img1 = _run_generator(before_img1)

# Pair 2: the randomly chosen image, freshly stamped. Previously the second
# output was computed from before_img1 again.
before_img2 = ToTensor()(adding_random_text(img).convert('RGB'))
after_img2 = _run_generator(before_img2)

fig = plt.figure(figsize=(4., 4.))
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(2, 2),  # 2x2 grid: one row per before/after pair
                 axes_pad=0.1,  # pad between axes in inch.
                 )

for ax, im in zip(grid, [before_img1, after_img1, before_img2, after_img2]):
    ax.imshow(_for_imshow(im))

