# Technical characteristics of the causal UNet with atrous convolutions

In [None]:
# Fetch the project repository; its directory is appended to sys.path in section I
# so that `src.utils.architectures` can be imported.
! git clone https://github.com/nanopiero/CML_processing_by_ML.git

In [3]:
import random
import shutil
import os
from os import listdir as ls
from os.path import join, isdir, isfile
import matplotlib.pyplot as plt
import numpy as np
import re
import json
import copy
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import time
import importlib
import sys
from torch.utils.data.sampler import SubsetRandomSampler
import sys

# <font color='blue'> I) Model Loading </font>

In [4]:
sys.path.append('CML_processing_by_ML')
import src.utils.architectures_fcn
from src.utils.architectures import load_archi
from src.utils.architectures_fcn import UNet_causal_5mn_atrous

# example with exp69 (no final fc layers)
"""
python learning/preprocessing/train_1GPU_MAE_PNP.py lastepo UNet_causal_5mn_atrous -lr 0.001 -bs 128 -ne 100
-pr 20240816_exp69_CE_MSE_MAE_15 -ss 15  -lrc -miv -10 -comp -ste 330000 -sm 100  -mcl CE_MSE_MAE -ap 16 -rr -dbs -dbci
"""

# Architecture hyper-parameters implied by the previous command (see src/train_1GPU.py).
arch = "UNet_causal_5mn_atrous"
nchannels = 2  # input channels (the inputs used below are [batch, 2, time])
nclasses = 3
# NOTE(review): this cell used to declare `dilation=2` while always building the
# model with `dilation=1`. The value actually used is kept and passed explicitly;
# confirm it against the training script.
dilation = 1
atrous_rates = [6, 12, 18, 24, 30, 36, 42]
additional_parameters = 16  # matches `-ap 16` in the command above

model = load_archi(arch, nchannels, nclasses, size=64, dilation=dilation,
                   atrous_rates=atrous_rates, fixed_cumul=False,
                   additional_parameters=additional_parameters)

# <font color='blue'> II) Model description </font>



In [None]:
! pip install torchsummary
from torchsummary import summary
summary(model, (2, 10000))

In [None]:
# To see the computation graph :
! pip install torchviz
from torchviz import make_dot
example_input = torch.rand(1, 2, 10000)
output = model(example_input)
dot = make_dot(output, params=dict(list(model.named_parameters()))).render("fcn_torchviz", format="png")
from IPython.display import Image
Image(filename='fcn_torchviz.png')

# <font color='blue'> III) Memory usage, Inference time and FLOPs </font>

In [16]:
import torch
import copy
# Random input of shape [1, 2, 2000]: a batch of 1, 2 channels, and 500 minutes
# sampled every 15 s (4 time steps per minute).
rand_500min_time_step_15sec = torch.rand(1, 2, 500 * 4)

In [None]:
from torch.profiler import profile, record_function, ProfilerActivity

# Profile one inference forward pass on CPU.
# - `profile_memory=True` is required for the profiler to record tensor
#   allocations; without it the memory columns stay empty.
# - `torch.no_grad()` keeps autograd buffers out of the measurement, consistent
#   with the inference-time cell below.
with torch.no_grad(), profile(activities=[ProfilerActivity.CPU],
                              record_shapes=True, profile_memory=True) as prof:
    with record_function("model_inference"):
        model(rand_500min_time_step_15sec)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))

In [None]:
# Inference time:
import timeit

n_runs = 10  # number of timed forward passes


def apply_model():
    """Run one gradient-free forward pass of `model` on the 500-minute input."""
    with torch.no_grad():  # Ensuring no gradients are calculated
        model(rand_500min_time_step_15sec)


# Untimed warm-up: the first forward pass typically pays one-off initialisation
# costs (allocations, kernel selection) that would otherwise bias the mean.
apply_model()

# Time each run separately (number=1) so that per-run statistics are available.
times = timeit.repeat(apply_model, number=1, repeat=n_runs)
mean_time = sum(times) / len(times)

print(f"Average inference time over {n_runs} runs: {mean_time:.4f} seconds")
print(f"Fastest run: {min(times):.4f} seconds")

In [None]:
# Inference FLOPs
import torch
import torchvision.models as models
! pip install fvcore
from fvcore.nn import FlopCountAnalysis

# Define a model, e.g., ResNet
model2 = models.resnet18()

# Create a sample input tensor with the correct shape
input2 = torch.randn(1, 3, 224, 224)

# Perform FLOP count
flops = FlopCountAnalysis(model2, input2)
# flops = FlopCountAnalysis(model, rand_500min_time_step_15sec)

# Print total FLOPs
print(f"Total FLOPs: {flops.total()}")

# <font color='blue'> IV) Causality and receptive field </font>

In [None]:
def test_causality_and_receptive_field(model, input_tensor1, input_tensor2, pos, channel=0):
    """Probe causality and influence span of `model` by perturbing one time step.

    `input_tensor1` and `input_tensor2` are expected to be identical except at
    time index `pos`. The model is run on both, and the output time indices
    where the two predictions differ are reported. If any of them precede `pos`,
    a later input influenced an earlier output, and the model is not causal.

    Args:
        model: torch module to probe. It is switched to eval mode, and that
            mode change persists after the call.
        input_tensor1, input_tensor2: inputs indexed as [batch, channel, time];
            only batch element 0 of the output is inspected.
        pos: time index at which the two inputs differ.
        channel: output channel to inspect (default 0, the historical behaviour);
            pass None to count a difference in any output channel.

    Returns:
        The difference `model(input_tensor1) - model(input_tensor2)`.
    """
    # Eval mode makes dropout / batch-norm layers (if any) deterministic.
    model.eval()
    with torch.no_grad():  # no autograd graph is needed for this probe
        output = model(input_tensor1) - model(input_tensor2)

    if channel is None:
        changed = (output[0] != 0).any(dim=0)
    else:
        changed = output[0, channel, :] != 0
    non_zero_indices = changed.nonzero(as_tuple=True)[0]

    if non_zero_indices.numel() == 0:
        # min()/max() of an empty tensor would raise; nothing to measure.
        print("No output is affected by the perturbation at index", pos)
        return output

    first_non_zero = non_zero_indices.min().item()
    last_non_zero = non_zero_indices.max().item()

    print("First non-zero output at index:", first_non_zero)
    print("Last non-zero output at index:", last_non_zero)
    # Both endpoints are affected outputs, hence the +1.
    print("length of receptive field:", last_non_zero - first_non_zero + 1)
    # Check causality: an affected output before `pos` means the future leaked in.
    if first_non_zero < pos:
        print("The model is not causal.")
    else:
        print("The model is causal.")

    return output


# The helper's name starts with `test_`; stop pytest from collecting it as a test.
test_causality_and_receptive_field.__test__ = False

# Two random inputs that are identical except at time index `pos`.
input_tensor1 = torch.rand(1, 2, 10000)
input_tensor2 = input_tensor1.clone()

# Perturb time step `pos` (set it to 10.) in both channels of the first input.
pos = 6000
input_tensor1[0, :, pos] = 10.


output = test_causality_and_receptive_field(model, input_tensor1, input_tensor2, pos)

# The model is causal only once every 5 minutes (1 time step = 15 sec, i.e. every 20 steps).
# NOTE(review): presumably `pos` should therefore be a multiple of 20 (6000 is)
# for the check above to report causality; confirm.