# Ноутбук с нейросетью для предсказания масок по дифракционным картинам

## Подготовка данных

Импортируем библиотеки

In [None]:
!pip install pyffs
!pip install cupy

In [2]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision

from tqdm import tqdm
from copy import deepcopy
from timeit import default_timer
from sklearn.model_selection import train_test_split
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix, log_loss
from sklearn.metrics import RocCurveDisplay, roc_curve, auc


import torch.nn as nn
import torch.nn.functional as F

from torchsummary import summary
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from torch.utils.data import random_split

import gc
from PIL import Image
from pyffs import ffsn, iffsn
import cupy as cp
import scipy.stats as sps

Подключаем Google Диск

In [3]:
# Mount Google Drive so that the files under /content/drive/MyDrive are readable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


Загружаем данные: бинарные маски и симуляции дифракции по ним

In [4]:
# Load the raw arrays from Google Drive.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code — acceptable only because these are the author's own files.
# sim_pattern is divided by 8 right after loading — presumably to match the
# intensity scale of `simulations`; TODO confirm where the factor comes from.
sim_pattern = np.load('/content/drive/MyDrive/sim_pattern.npy', allow_pickle=True) / 8
simulations = np.load('/content/drive/MyDrive/simulations.npy', allow_pickle=True)  # model inputs (X below)
dataset = np.load('/content/drive/MyDrive/dataset.npy', allow_pickle=True)  # masks, training targets (y below)

Для работы с полным датасетом не хватает вычислительных ресурсов, поэтому берём только его часть: каждый `data_step`-й пример из диапазона `[N_start, N_end)`.

In [16]:
# Every tensor below is placed on the GPU; the notebook assumes CUDA is available.
device = torch.device('cuda')

# Use every `data_step`-th sample from [N_start, N_end) to reduce memory usage.
N_start = 0
N_end = simulations.shape[0]
data_step = 2
# X: simulated patterns (inputs), y: masks (targets).
# Both are reshaped to (N, 1, 513, 513) and converted to float16.
X = torch.from_numpy(simulations[N_start:N_end:data_step].reshape(-1, 1, 513, 513)).half().to(device)
y = torch.from_numpy(dataset[N_start:N_end:data_step].reshape(-1, 1, 513, 513)).half().to(device)
# Z: the first 500 patterns of sim_pattern, without target masks; used for
# validation and testing.
Z = torch.from_numpy(sim_pattern[:500].reshape(-1, 1, 513, 513)).half().to(device)

Освобождаем память

In [None]:
# Drop the host-side NumPy arrays: X, y and Z are separate float16 copies on
# the GPU, so the originals are no longer needed. Re-run the loading cell
# before re-running the tensor-preparation cell.
del simulations, dataset, sim_pattern
gc.collect()

Делим данные: модель обучается на симуляциях (`X`, `y`), а картины `Z` делятся пополам на валидацию и тест

In [17]:
# Batch size shared by all loaders.
bs = 5

# Training uses the simulated (pattern, mask) pairs X, y. The patterns in Z
# have no target masks, so they are split 50/50 into validation and test sets.
Z_val, Z_test = train_test_split(Z, test_size=0.5, random_state=42)

train_data = TensorDataset(X, y)
val_data = TensorDataset(Z_val)
test_data = TensorDataset(Z_test)

# Shuffle the training set so each epoch sees the batches in a new order.
train_loader = DataLoader(train_data, batch_size=bs, shuffle=True)
val_loader = DataLoader(val_data, batch_size=bs)
test_loader = DataLoader(test_data, batch_size=bs)

# Gradient scaler for mixed-precision (float16) training in `train`.
scaler = torch.cuda.amp.GradScaler()

## Симуляция ATL

Функции из ноутбука atl_simulation.ipynb

In [12]:
def my_matmul(a, m):
    '''
    Scale matrix `m` by `a`.

    param a: a scalar, or a 1-D sequence of coefficients,
    param m: 2-D array

    return: `a * m` for a scalar `a`. Otherwise an array of shape
        (len(a), m.shape[0], m.shape[1]) whose i-th slice is `a[i] * m`.
    '''
    # NOTE(review): compute_carpet passes CuPy arrays here; presumably np.outer
    # dispatches to CuPy through __array_function__ — TODO confirm.
    if not np.isscalar(a):
        stacked = np.outer(a, m)
        return stacked.reshape(len(a), m.shape[0], m.shape[1])
    return a * m
    
def compute_carpet(mask, wavelength, T_x, T_y, z):
    '''
    Compute the Talbot-carpet intensity of `mask` at distance z.

    param mask: mask image (intensity). Its first dimension sets the number
        of Fourier-series coefficients used along both axes,
    param wavelength: 1-D array of wavelengths [mm],
    param T_x: mask period along x [mm],
    param T_y: mask period along y [mm],
    param z: distance from the mask to the pattern [mm]

    return: |field|^2, one image per wavelength along the leading axis.
    '''
    # Field amplitude is the square root of the intensity mask.
    amplitude = cp.sqrt(mask)

    size = amplitude.shape[0]
    half = (size - 1) // 2

    periods = cp.array([T_x, T_y])
    centers = periods / 2
    n_coeffs = [size, size]

    coeffs = ffsn((amplitude), periods, centers, n_coeffs)

    # Spatial frequencies of harmonics -half..half.
    # NOTE(review): ffsn presumably pairs periods[i] with mask axis i, while
    # periods[0] (T_x) here scales the frequencies laid out along axis 1.
    # This is harmless when T_x == T_y (the compute_atl defaults); TODO confirm.
    harmonics = cp.arange(-half, half + 1)
    freq_x = cp.reshape(harmonics / periods[0], (1, -1))
    freq_y = cp.reshape(harmonics / periods[1], (-1, 1))
    freq_sq = cp.power(freq_x, 2) + cp.power(freq_y, 2)

    # Paraxial (Fresnel) transfer function for each wavelength.
    # NOTE(review): a scalar wavelength yields a 2-D product, for which
    # axes=[1, 2] below looks out of range — TODO confirm.
    transfer = cp.exp(-1j * cp.pi * z * my_matmul(wavelength, freq_sq))
    field = iffsn((coeffs * transfer), periods, centers, n_coeffs, axes=[1, 2])
    return cp.square(cp.absolute(field))

def compute_atl(mask, T_x=0.001, T_y=0.001, wl=1.35e-5, rel_delta_wl=0.04, grid_size=500):
    '''
    Compute the ATL diffraction pattern of a mask (presumably achromatic Talbot
    lithography), i.e. the Talbot carpet averaged over a spectral band.

    The source spectrum is sampled at `grid_size` frequencies spread uniformly
    over freq * (1 +- rel_delta_wl). The samples are weighted by a standard
    normal pdf evaluated on [-2, 2], so the band edges sit at +-2 sigma. The
    carpet is evaluated at distance 2 * z_A with
    z_A = 2 * max(T_x, T_y)**2 / (rel_delta_wl * wl).

    param mask: mask image,
    param T_x: mask period along x [mm],
    param T_y: mask period along y [mm],
    param wl: central wavelength [mm],
    param rel_delta_wl: relative half-width of the spectral band (dimensionless,
        must be non-zero),
    param grid_size: number of spectral samples (default 500)

    return: weighted mean intensity over the spectrum, as a float16 CuPy array.
    '''
    c = 299792458e3  # speed of light [mm/s]
    freq = c / wl
    delta_freq = freq * rel_delta_wl
    freq_grid = cp.linspace(freq - delta_freq, freq + delta_freq, grid_size)

    # The weights are used by scipy on the host, so build the grid with NumPy
    # directly instead of building it on the GPU and copying it back.
    intens = sps.norm.pdf(np.linspace(-2, 2, grid_size))

    z_A = 2 * max(T_x, T_y) ** 2 / (rel_delta_wl * wl)
    carp = compute_carpet(mask, (c / freq_grid), T_x, T_y, 2 * z_A)
    # cp.average normalises by the weight sum, so the pdf need not sum to 1.
    return cp.average(carp, axis=0, weights=intens).astype(cp.float16)

def compute_atl_4tensor(X):
    '''
    Simulate the ATL pattern of every mask in a batch tensor.

    param X: tensor with 513*513 elements per mask, e.g. model output of shape
        (B, 1, 513, 513); may be on the CPU or the GPU

    return: float16 tensor of shape (B, 1, 513, 513) on `device`. It is computed
        in CuPy, outside autograd, so no gradient flows back to X.
    '''
    # Masks are simulated as 8-bit images. Clamp first so out-of-range values
    # saturate at 0/255 instead of overflowing in the float -> uint8 cast.
    # detach() lets CuPy read tensors that require grad.
    masks = cp.asarray(X.detach().clamp(0, 255), dtype=cp.uint8).reshape(-1, 513, 513)
    pics = [compute_atl(mask).get() for mask in masks]
    # No requires_grad: the result is not connected to the autograd graph.
    return torch.tensor(np.array(pics).reshape(-1, 1, 513, 513)).half().to(device)


## Обучение модели

Функция, которая обучает модель

In [8]:
def train(model, criterion, optimizer, X, n_epochs, n_stop=10):
    '''
    Train `model` with mixed precision and early stopping.

    Uses the notebook globals train_loader, val_loader, scaler, device and
    compute_atl_4tensor.
    - Training minimises criterion(model(x), y) over the (x, y) batches of
      train_loader.
    - Validation feeds each pattern of val_loader through the model, simulates
      the ATL pattern of the predicted mask, and compares it with the input.

    param model: network to train (updated in place),
    param criterion: loss function, e.g. torch.nn.MSELoss(),
    param optimizer: optimizer over model.parameters(),
    param X: unused; kept for backward compatibility (data comes from train_loader),
    param n_epochs: maximum number of epochs,
    param n_stop: stop after this many consecutive epochs in which the
        validation loss (rounded to 3 decimals) did not decrease

    return: deep copy of the model with the lowest validation loss, or `model`
        itself if no epoch produced a loss below +inf.
    '''
    min_val_loss = np.inf
    f = 0  # consecutive epochs without improvement
    # Fallback so there is always something to return (n_epochs == 0, or a
    # validation loss that never drops below +inf).
    best_model = model
    for epoch in range(n_epochs):
        model.train()
        for _X, _y in train_loader:
            optimizer.zero_grad()
            _X, _y = _X.to(device), _y.to(device)

            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                _y_pred = model(_X)
                loss = criterion(_y_pred, _y)
            # Scaled backward/step to avoid float16 gradient underflow.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        model.eval()
        mean_val_loss = 0
        batch_count = 0
        # Validation needs no gradients. Without no_grad() autograd records the
        # forward graph and keeps its activations alive.
        with torch.no_grad():
            for batch in val_loader:
                _X = batch[0].to(device)
                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                    _y_pred = model(_X)
                    _X_sim = compute_atl_4tensor(_y_pred)
                    loss = criterion(_X, _X_sim)
                mean_val_loss += loss.item()
                batch_count += 1

        mean_val_loss /= batch_count

        if epoch % 10 == 0:
            print(f'Epoch: {epoch}, \tValidation loss: {(mean_val_loss)}')

        # Compare at 3-decimal precision so tiny fluctuations don't count as progress.
        rounded_loss = round(mean_val_loss, 3)
        if rounded_loss >= min_val_loss:
            f = f + 1
        else:
            f = 0
            best_model = deepcopy(model)
            min_val_loss = rounded_loss
        if f >= n_stop:
            print(f'epoch: {epoch}, val loss did not decrease for {f} epoch(s)')
            break

    return best_model

Функция для подсчета ошибки на тесте

In [9]:
def test_loss(model, criterion):
  '''
  Return the mean of criterion(pattern, simulated pattern) over test_loader.

  Each test pattern is fed through the model on the CPU in float32. The ATL
  pattern of the predicted mask is then simulated and compared with the input.
  Uses the notebook globals test_loader and compute_atl_4tensor.

  Side effect: the model is moved to the CPU and switched to eval mode.
  '''
  # Moving the model is loop-invariant: do it once, not per batch.
  model.cpu()
  model.eval()
  mean_loss = 0
  batch_count = 0
  with torch.no_grad():
    for batch in test_loader:
      _X = batch[0].cpu().float()
      _y_pred = model(_X)
      # Bring the simulation back to CPU float32 to match _X.
      _X_sim = compute_atl_4tensor(_y_pred).cpu().float()
      loss = criterion(_X, _X_sim)
      mean_loss += loss.item()
      batch_count += 1

  return mean_loss / batch_count

Два варианта модели CNN

In [None]:
# Define the CNN architecture
# Default base channel count (conv1 has ch_b channels, conv2 has 2 * ch_b).
ch_b = 4
class CNN1(torch.nn.Module):
  '''
  Map a 513x513 pattern to a 513x513 mask.

  Two 3x3 dilated convolutions, each followed by ReLU and 2x2 max-pooling,
  reduce the input to (2 * channels, 128, 128). Two fully connected layers then
  map it to 513 * 513 values. A final ReLU keeps the output non-negative.

  param channels: channels of the first convolution; None means "use the
      global ch_b at construction time" (the previous behaviour).
  '''
  def __init__(self, channels=None):
    super().__init__()
    if channels is None:
      channels = ch_b
    self.conv1 = torch.nn.Conv2d(1, channels, kernel_size=3, padding=2, dilation=2)  # 513
    self.pool1 = torch.nn.MaxPool2d(kernel_size=2)                                    # 256
    self.conv2 = torch.nn.Conv2d(channels, channels * 2, kernel_size=3, padding=2, dilation=2)  # 256
    self.pool2 = torch.nn.MaxPool2d(kernel_size=2)                                    # 128
    self.fc1 = torch.nn.Linear(channels * 2 * 128 * 128, 128)
    self.fc2 = torch.nn.Linear(128, 513 * 513)

  def forward(self, x):
    x = self.pool1(torch.relu(self.conv1(x)))
    x = self.pool2(torch.relu(self.conv2(x)))
    # Flatten to the size fc1 was built for, not the global ch_b (which may
    # have been reassigned since construction). Works for batched
    # (B, 1, 513, 513) and unbatched (1, 513, 513) input.
    x = x.view(-1, self.fc1.in_features)
    x = torch.relu(self.fc1(x))
    x = torch.relu(self.fc2(x))
    return x.view(-1, 1, 513, 513)

In [14]:
# Define the CNN architecture
# Default base channel count (conv1 has ch_b channels, conv2 has 2 * ch_b).
ch_b = 4
class CNN2(torch.nn.Module):
  '''
  Map a 513x513 pattern to a 513x513 mask (wider-kernel variant of CNN1).

  Two 6x6 dilated convolutions, each followed by ReLU and 2x2 max-pooling,
  reduce the input to (2 * channels, 128, 128). Two fully connected layers then
  map it to 513 * 513 values. A final ReLU keeps the output non-negative.

  param channels: channels of the first convolution; None means "use the
      global ch_b at construction time" (the previous behaviour).
  '''
  def __init__(self, channels=None):
    super().__init__()
    if channels is None:
      channels = ch_b
    self.conv1 = torch.nn.Conv2d(1, channels, kernel_size=6, padding=5, dilation=2)  # 513
    self.pool1 = torch.nn.MaxPool2d(kernel_size=2)                                    # 256
    self.conv2 = torch.nn.Conv2d(channels, channels * 2, kernel_size=6, padding=5, dilation=2)  # 256
    self.pool2 = torch.nn.MaxPool2d(kernel_size=2)                                    # 128
    self.fc1 = torch.nn.Linear(channels * 2 * 128 * 128, 128)
    self.fc2 = torch.nn.Linear(128, 513 * 513)

  def forward(self, x):
    x = self.pool1(torch.relu(self.conv1(x)))
    x = self.pool2(torch.relu(self.conv2(x)))
    # Flatten to the size fc1 was built for, not the global ch_b (which may
    # have been reassigned since construction). Works for batched
    # (B, 1, 513, 513) and unbatched (1, 513, 513) input.
    x = x.view(-1, self.fc1.in_features)
    x = torch.relu(self.fc1(x))
    x = torch.relu(self.fc2(x))
    return x.view(-1, 1, 513, 513)

В этой ячейке можно проверить работу слоя

In [None]:
# Scratch check of a layer's output size; swap in the commented layer to try it.
# m = nn.MaxPool2d(kernel_size=4)
m = torch.nn.Conv2d(1, 1, kernel_size=3, padding=2, dilation=2, stride=2)
s = 256
# Named `sample` (not `input`) so the builtin input() is not shadowed.
sample = torch.randn(1, s, s)
output = m(sample)
output.size()

torch.Size([1, 128, 128])

Обучаем модель

In [18]:
# Instantiate the CNN on the GPU (.to(device) already moves it there).
model = CNN2().to(device)

# Define the loss function and optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

# Train with early stopping; `train` returns the best model it found.
# To debug NaNs, wrap this call in torch.autograd.detect_anomaly(check_nan=True).
model = train(model, criterion, optimizer, X, n_epochs=50, n_stop=10)

Epoch: 0, 	Validation loss: 276.046875
Epoch: 10, 	Validation loss: 276.046875
epoch: 10, val loss did not decrease for 10 epoch(s)


Можно загрузить уже обученную модель

In [None]:
# Load a previously trained model that was saved whole (pickled) with torch.save.
# The model class must already be defined in this session.
# weights_only=False is required for whole pickled modules on PyTorch >= 2.6,
# where torch.load defaults to weights_only=True. Unpickling can execute
# arbitrary code, so only load files you trust.
model = torch.load('/content/drive/MyDrive/models/model2', map_location=torch.device('cpu'), weights_only=False).to(device)

Считаем ошибку на тесте

In [None]:
# Mean test MSE between each test pattern and the ATL simulation of the mask the
# model predicts for it. Note: test_loss moves the model to the CPU.
criterion = torch.nn.MSELoss()
print(f'Test loss: {test_loss(model, criterion)}')

Test loss: 4123.008231026785


Сохраняем модель, если результат устраивает

In [None]:
# Save the whole model object (pickled) to Drive. Loading it back needs the
# model class to be defined in the session.
torch.save(model, '/content/drive/MyDrive/models/model4123')

## Проверка модели на практике

Следующие ячейки предсказывают маску для конкретной картинки. Предсказание можно сохранить, построить по нему симуляцию и сравнить её с исходной картинкой.

In [None]:
# Load one pattern as a grayscale image and divide it by 64.
# NOTE(review): sim_pattern is divided by 8 instead — confirm which scaling
# the model expects for real images.
circle_np = np.array(Image.open('/content/drive/MyDrive/prediction_test/cooler_im10.png').convert("L")) / 64
# circle_np = np.load('/content/drive/MyDrive/dif/sim18.npy')
# Unbatched (1, 513, 513) input; the image must contain 513 * 513 pixels.
circle =  torch.from_numpy(circle_np.reshape(1, 513, 513)).half().to(device)
model.cuda()  # test_loss may have left the model on the CPU
model.eval()
# Inference only: no_grad() avoids recording an autograd graph.
with torch.no_grad(), torch.amp.autocast(device_type='cuda', dtype=torch.float16):
  circle_pred = model(circle)

In [None]:
def frame_plot(arr, ax):
  '''
  Draw a 2-D array on `ax` with pcolormesh, mapping both axes onto [0, 1].

  Rows of `arr` run along the vertical axis and columns along the horizontal
  one, so non-square arrays work too.

  param arr: 2-D array to draw,
  param ax: matplotlib Axes to draw on
  '''
  xs = np.linspace(0, 1, arr.shape[1])  # one x coordinate per column
  ys = np.linspace(0, 1, arr.shape[0])  # one y coordinate per row
  x, y = np.meshgrid(xs, ys)  # both have shape arr.shape, as pcolormesh requires
  ax.pcolormesh(x, y, arr)

In [None]:
# Detach from autograd, copy to host memory and drop the batch/channel axes.
circle_pred_np = circle_pred.detach().cpu().numpy().reshape(513, 513)

In [None]:
# Side by side: predicted mask (left) and the input pattern (right).
fig, ax = plt.subplots(1, 2, figsize=(16, 8))
frame_plot(circle_pred_np, ax=ax[0])
frame_plot(circle_np, ax=ax[1])

In [None]:
# Value of the predicted mask at the centre pixel.
circle_pred_np[256, 256]

74.3

In [None]:
# Save the prediction as an 8-bit grayscale image. Clip to [0, 255] first:
# casting out-of-range floats straight to uint8 is undefined in NumPy
# (it typically wraps around). A 2-D uint8 array already gives a mode 'L' image.
circle_im = Image.fromarray(np.clip(circle_pred_np, 0, 255).astype(np.uint8))
circle_im.show()
circle_im.save('/content/drive/MyDrive/prediction_test/circle_pred.png')