In [1]:
import torch
import taichi as ti
from torch import optim
from diff_tem.simulation_grad import Simulation
import logging
from copy import deepcopy
from focal_frequency_loss import FocalFrequencyLoss as FFL
from torch.utils.tensorboard import SummaryWriter
import os

[Taichi] version 1.5.0, llvm 15.0.1, commit 7b885c28, win, python 3.9.16


In [2]:
# Work from the test-fixture directory so relative paths such as "Blank/input.txt" resolve.
# Guarded so that re-running this cell does not try to descend into diff_tem/tests a second
# time (which would raise FileNotFoundError).
TESTS_DIR = os.path.join("diff_tem", "tests")
if not os.path.normcase(os.getcwd()).endswith(os.path.normcase(os.sep + TESTS_DIR)):
    os.chdir(TESTS_DIR)
torch.set_default_dtype(torch.float64)
# Seed torch so the randomized initial detector parameters (torch.rand_like in the
# requires_grad cell) are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
ti.init(arch=ti.vulkan)

[Taichi] Starting on arch=vulkan


In [3]:
# basicConfig is a no-op when the root logger already has handlers (common in Jupyter),
# so the level is also set explicitly on the root logger.
logging.basicConfig(level=logging.INFO)
logging.getLogger().setLevel(logging.INFO)
_logger = logging.getLogger(__name__)


def _write_log(string, *logging_strings):
    """Log one or more strings as a single INFO record prefixed with the logger name.

    A single string is logged on the same line as the logger name; several strings are
    logged on separate, indented lines below it.

    Records are emitted at INFO rather than DEBUG because the level configured above is
    INFO: DEBUG records would be dropped and this notebook would display nothing.

    Args:
        string: First (or only) message line.
        *logging_strings: Additional message lines.
    """
    strings = [string, *logging_strings]
    if len(strings) > 1:
        logging_string = "\n    ".join(strings)
        # stacklevel=2 attributes the record to the caller, not to this helper.
        _logger.info(f"{_logger.name}:\n    {logging_string}", stacklevel=2)
    else:
        _logger.info(f"{_logger.name}:  {strings[0]}", stacklevel=2)
        
# TensorBoard writer used by the training cell for parameter and loss curves.
# No log_dir is given, so it uses PyTorch's default location (a "runs/..." subdirectory),
# resolved relative to the current working directory set by the chdir in the setup cell.
# Presumably this is the "runs" directory that the %tensorboard magic at the end reads.
writer = SummaryWriter()

# Learn detector parameters from an empty (blank) scene

In [4]:
# Reference setup: an empty ("Blank") scene read from the test fixtures.
file_path = "Blank/input.txt"
simulation = Simulation.construct_from_file(file_path)
simulation.finish_init_()

# The simulation's own detectors are the ground truth. The learned detectors are independent
# deep copies, so optimizing them never mutates the reference ones.
ref_detectors = simulation.detectors
learned_detectors = deepcopy(ref_detectors)

In [5]:
# Detector parameters that are learned by gradient descent.
learning_parameter_names = ["mtf_a", "mtf_b", "mtf_c", "mtf_alpha", "mtf_beta", "mtf_p", "mtf_q", "gain"]

# Enable gradients on the learned detectors' parameters.
for d in learned_detectors:
    d.parameters_require_grad = True

# Randomize the starting point. The learned detectors are copies of the reference ones,
# so each reference value is scaled by an independent uniform [0, 1) factor.
for d in learned_detectors:
    for parameter_name in learning_parameter_names:
        val = getattr(d, parameter_name).detach()
        setattr(d, parameter_name, val * torch.rand_like(val))

# The optimizers are built from the private "_<name>" tensors. If the setter above replaced
# them with tensors that do not require grad, Adam would silently never update them.
# Fail loudly here instead of training nothing.
for d in learned_detectors:
    for parameter_name in learning_parameter_names:
        if not getattr(d, "_" + parameter_name).requires_grad:
            raise RuntimeError(
                f"Learned parameter '_{parameter_name}' does not require grad after "
                f"randomization; the optimizer would never update it."
            )

In [6]:
# Split each detector sequence into the two detectors named (noisy, noise-free), in that order.
# Unpacking also fails if a sequence does not hold exactly two detectors.
[learned_noisy_detector, learned_noise_free_detector] = learned_detectors
[ref_noisy_detector, ref_noise_free_detector] = ref_detectors

In [7]:
# For every learned parameter, report the reference value next to both randomized starting values.
for name in learning_parameter_names:
    ref_value = getattr(ref_noisy_detector, name)
    noisy_start = getattr(learned_noisy_detector, name)
    clean_start = getattr(learned_noise_free_detector, name)
    _write_log(
        f"Ref {name} = {ref_value}, ",
        f"initial learning {name} (from noisy ref) = {noisy_start}, ",
        f"initial learning {name} (from clean ref) = {clean_start}",
    )

In [8]:
# Optimization steps per run of the training cell, and the number of detector samples
# rendered (and stacked) per tilt for both the references and the predictions.
iterations, sample_num = 20, 20

In [9]:
# One Adam optimizer per learned detector, over that detector's private parameter tensors.
lr = 0.1
noisy_params = [getattr(learned_noisy_detector, f"_{pn}") for pn in learning_parameter_names]
noise_free_params = [getattr(learned_noise_free_detector, f"_{pn}") for pn in learning_parameter_names]
optimizer_noisy = optim.Adam(noisy_params, lr=lr)
optimizer_noise_free = optim.Adam(noise_free_params, lr=lr)

In [10]:
# Focal frequency loss used to compare predicted against reference detector outputs.
ffl = FFL(alpha=1.0, loss_weight=1.0)

In [15]:
# Run the reference simulation once; the pre-detector wave functions are reused for both
# the references built here and the predictions in the training loop.
ref_results, predetector_vvfs, snode_trees = simulation.generate_micrographs()
ref_noisy_predetector_vvfs, ref_noise_free_predetector_vvfs = predetector_vvfs

# For every sample and every tilt:
#   - the noisy reference detector applies quantization followed by its MTF;
#   - the noise-free reference detector applies only its MTF.
ref_noisy_results = []
ref_noise_free_results = []
for _sample in range(sample_num):
    noisy_per_tilt = []
    clean_per_tilt = []
    for noisy_vvf_in, clean_vvf_in in zip(ref_noisy_predetector_vvfs, ref_noise_free_predetector_vvfs):
        quantized = ref_noisy_detector.apply_quantization(noisy_vvf_in.values)
        noisy_per_tilt.append(ref_noisy_detector.apply_mtf(quantized))
        clean_per_tilt.append(ref_noise_free_detector.apply_mtf(clean_vvf_in.values))
    ref_noisy_results.append(torch.stack(noisy_per_tilt))
    ref_noise_free_results.append(torch.stack(clean_per_tilt))

# Stack into (sample_num, tilt_num, *resolution); the trailing dims are presumably the
# detector resolution — TODO confirm against the detector implementation.
ref_noisy_results = torch.stack(ref_noisy_results)
ref_noise_free_results = torch.stack(ref_noise_free_results)

In [16]:
# Global TensorBoard step counter, incremented once per optimization step by the training cell.
# It lives in its own cell so the training cell can be re-run to continue training
# without restarting the logged step at 0.
i = 0

In [17]:
for _ in range(iterations):
    optimizer_noisy.zero_grad()
    optimizer_noise_free.zero_grad()

    # Render sample_num realisations of every tilt with the current learned parameters,
    # using the same pipeline that produced the references.
    pred_noisy_results = []
    pred_noise_free_results = []
    for s in range(sample_num):
        noisy_vvf_values = []
        noise_free_vvf_values = []
        for np_vvf, nfp_vvf in zip(ref_noisy_predetector_vvfs, ref_noise_free_predetector_vvfs):
            noisy_vvf = learned_noisy_detector.apply_quantization(np_vvf.values)
            noisy_vvf = learned_noisy_detector.apply_mtf(noisy_vvf)
            noisy_vvf_values.append(noisy_vvf)

            noise_free_vvf = learned_noise_free_detector.apply_mtf(nfp_vvf.values)
            noise_free_vvf_values.append(noise_free_vvf)

        pred_noisy_results.append(torch.stack(noisy_vvf_values))
        pred_noise_free_results.append(torch.stack(noise_free_vvf_values))

    pred_noisy_results = torch.stack(pred_noisy_results)  # (sample_num, tilt_num, *resolution)
    pred_noise_free_results = torch.stack(pred_noise_free_results)  # (sample_num, tilt_num, *resolution)

    # Only the real parts are compared (presumably the detector outputs are complex-typed).
    loss_noisy = ffl(pred_noisy_results.real, ref_noisy_results.real)
    loss_noise_free = ffl(pred_noise_free_results.real, ref_noise_free_results.real)
    # Each loss only involves its own learned detector, so the two backward passes are separate.
    loss_noisy.backward()
    loss_noise_free.backward()

    # Log the parameter values the losses were evaluated at (before this step's update).
    for pn in learning_parameter_names:
        writer.add_scalars(f"Parameters/{pn}", {
            "learned_from_noisy": getattr(learned_noisy_detector, pn).detach(),
            "learned_from_noise_free": getattr(learned_noise_free_detector, pn).detach(),
            "ref": getattr(ref_noisy_detector, pn).detach(),
            "ref_noise_free": getattr(ref_noise_free_detector, pn).detach(),
        }, i)
    # Log both losses once per step. Previously this was inside the parameter loop, which
    # re-logged the noisy loss once per parameter, and the noise-free loss was never logged.
    writer.add_scalars("Training Loss", {
        "Noisy": loss_noisy.detach(),
        "Noise-free": loss_noise_free.detach(),
    }, i)

    optimizer_noisy.step()
    optimizer_noise_free.step()
    i += 1

# Push buffered events to disk so TensorBoard sees the latest steps.
writer.flush()

In [19]:
# Show the TensorBoard dashboard inline. "runs" is presumably resolved relative to the
# kernel's working directory, which the setup cell changed to diff_tem/tests/.
# Re-running this cell only prints an "already loaded" notice for the extension.
%load_ext tensorboard
%tensorboard --logdir=runs

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 37632), started 0:26:19 ago. (Use '!kill 37632' to kill it.)