In [None]:
%%capture
!pip install ai-edge-torch==0.4.0
!pip install gcpds-cv-pykit

In [None]:
import os
import random
from itertools import islice

import wandb
import ai_edge_torch
import torch
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from ai_edge_litert.interpreter import Interpreter
from gcpds_cv_pykit.datasets import FeetMamitas
from ai_edge_litert.interpreter import Interpreter
from gcpds_cv_pykit.baseline.dataloaders import Segmentation_DataLoader
from gcpds_cv_pykit.visuals import random_sample_visualization
from gcpds_cv_pykit.baseline.models import UNet

In [None]:
# Seed every global RNG the notebook may draw from: Python `random`, NumPy's
# legacy global generator, and PyTorch (CPU + all CUDA devices). This makes
# `torch.randn` sample inputs reproducible, and presumably also any augmentation /
# random sampling inside gcpds_cv_pykit (confirm which RNG that library uses).
seed = 42

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Covers every CUDA device, so a separate torch.cuda.manual_seed call is redundant.
torch.cuda.manual_seed_all(seed)

In [None]:
# Dataset helper from gcpds_cv_pykit.
# NOTE(review): presumably this downloads/prepares the Feet Mamitas dataset; the
# following cell copies the splits from /kaggle/input/feet-mamitas — confirm
# whether this call is still required and where it writes its files.
FeetMamitas()

In [None]:
!cp -r /kaggle/input/feet-mamitas/Test /kaggle/input/feet-mamitas/Train /kaggle/input/feet-mamitas/Valid /kaggle/working/datasets

In [None]:
# Experiment configuration used by the data loaders, model construction and export.
# "Wandb monitoring" = [API key, project, run name]. The API key is read from the
# WANDB_API_KEY environment variable (None if unset) instead of being hardcoded:
# any key ever committed to this notebook must be treated as leaked and revoked.
config = {
    "Dir of dataset": "/kaggle/working/datasets",
    "Batch size": 36,
    "Image size": (256, 256),
    "Input size": (3, 256, 256),
    "Number of classes": 1,
    "Single class train": None,
    "Single class test": None,
    "Single class valid": None,
    "Images folder": "images",
    "Data augmentation": True,
    "Epochs": 61,
    "Device": "cuda:0",
    "AMixPre": False,
    "Model": "UNet",
    "Backbone": "resnet34",
    "Pretrained": True,
    "Activation function": "sigmoid",
    "Loss function": "DICE",
    "Save results": True,
    "Train phases": True,
    "Wandb monitoring": [os.environ.get("WANDB_API_KEY"), "MasterTests", "FeetMamitas-UNet-DICE"],
}

In [None]:
# Training split loader: configured batch size, with data augmentation as set in config.
train_dataset = Segmentation_DataLoader(
    config["Dir of dataset"],
    config["Batch size"],
    config["Image size"],
    config["Number of classes"],
    "Train",
    config["Single class train"],
    config["Data augmentation"],
    config["Images folder"],
    pin_memory=False,
)

In [None]:
# Validation split loader, one image per batch. It feeds both the TFLite
# representative dataset and the single-image inference check, so it is built
# WITHOUT data augmentation: evaluation/calibration inputs should be the real,
# untransformed images (augmentation belongs to the training split only).
valid_dataset = Segmentation_DataLoader(
    config["Dir of dataset"],
    1,
    config["Image size"],
    config["Number of classes"],
    "Valid",
    config["Single class valid"],
    False,  # data augmentation (same positional slot as config["Data augmentation"] for Train)
    config["Images folder"],
    pin_memory=False,
)

In [None]:
# Test split loader. Data augmentation is disabled: held-out evaluation must run
# on the real, untransformed images (augmentation belongs to the training split only).
test_dataset = Segmentation_DataLoader(
    config["Dir of dataset"],
    config["Batch size"],
    config["Image size"],
    config["Number of classes"],
    "Test",
    config["Single class test"],
    False,  # data augmentation (same positional slot as config["Data augmentation"] for Train)
    config["Images folder"],
    pin_memory=False,
)

In [None]:
# Visual sanity check of the training loader: plot randomly chosen samples.
random_sample_visualization(
    type='baseline',
    dataset=train_dataset,
    num_classes=config["Number of classes"],
    single_class=config["Single class train"],
)

### UNet trained with Dice loss: FP16 TFLite export and inference check

In [None]:
# Download the trained checkpoint from Weights & Biases.
# wandb.login() without `key` picks up the WANDB_API_KEY environment variable (or an
# existing ~/.netrc login, else prompts), so no secret is hardcoded in the notebook.
wandb.login()
run = wandb.init()
# Pinned artifact version (v53) of the `best_model` checkpoint.
artifact = run.use_artifact('gcpds/MasterTests/best_model:v53', type='model')
artifact_dir = artifact.download()
# The run is only used to fetch the artifact; close it instead of leaving it open.
run.finish()

In [None]:
# Rebuild the UNet and load the downloaded weights onto the CPU.
model = UNet(in_channels=config["Input size"][0],
             out_channels=config["Number of classes"],
             final_activation=config["Activation function"])
# Load from the directory artifact.download() actually returned, instead of a
# hardcoded '/kaggle/working/artifacts/best_model:v53/...' path that silently
# breaks outside /kaggle/working or when the artifact version changes.
state_dict = torch.load(f"{artifact_dir}/best_model.pt", weights_only=True,
                        map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.eval()
# Wrap the model so input 0 is taken in channel-last (NHWC) layout for the TFLite
# export; `outputs` is not given, so only the input layout is converted.
model_nhwc = ai_edge_torch.to_channel_last_io(model, args=[0])

In [None]:
def representative_dataset(loader=None, max_batches=None):
    """Yield calibration batches for the TFLite converter as NHWC float32 arrays.

    The converter calls this with no arguments, which preserves the original
    behaviour: every batch of ``valid_dataset`` is yielded.

    Args:
        loader: Iterable of ``(image, mask)`` pairs where ``image`` is a torch
            tensor, assumed NCHW. ``None`` (default) means ``valid_dataset``,
            looked up at call time.
        max_batches: Optional cap on the number of batches yielded. ``None``
            (default) walks the whole loader; calibration rarely needs more than
            a few hundred samples.

    Yields:
        A one-element list holding the batch as a ``np.float32`` array with axes
        reordered ``(0, 2, 3, 1)``, i.e. NCHW -> NHWC.
    """
    if loader is None:
        loader = valid_dataset
    # islice stops before fetching the batch past the cap (and is unbounded for None).
    for data, _ in islice(loader, max_batches):
        arr = data.numpy().astype(np.float32)
        yield [np.transpose(arr, (0, 2, 3, 1))]

In [None]:
# Flags forwarded to the TFLite converter through
# ai_edge_torch.convert(..., _ai_edge_converter_flags=...): default optimizations
# with float16 as the supported type (FP16 quantization), builtin ops only.
# NOTE(review): as far as I know the representative dataset is only consumed for
# integer quantization and is ignored for float16 — confirm before relying on it.
tfl_converter_flags = dict(
    optimizations=[tf.lite.Optimize.DEFAULT],
    representative_dataset=representative_dataset,
    target_spec=dict(
        supported_ops=[tf.lite.OpsSet.TFLITE_BUILTINS],
        supported_types=[tf.float16],
    ),
)

In [None]:
# Convert the NHWC-wrapped model to TFLite using `tfl_converter_flags`.
# The sample input only fixes the traced signature: batch 1, channel-last, with
# channels and spatial size taken from config rather than a hardcoded 256x256.
C, H, W = config["Input size"]
sample_input = (torch.randn(1, H, W, C),)
edge_model = ai_edge_torch.convert(
    model_nhwc,
    sample_input,
    _ai_edge_converter_flags=tfl_converter_flags,
)

# Written relative to the current working directory.
edge_model.export("model_fp16.tflite")

In [None]:
# Load the exported TFLite model and print its input/output signatures.
# Read from the same relative path edge_model.export() writes to, instead of a
# hardcoded '/kaggle/working/...' path that only matches when the working
# directory happens to be /kaggle/working.
interpreter = Interpreter(model_path="model_fp16.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

print("=== INPUTS ===")
for i, d in enumerate(input_details):
    print(f"Input {i}: name={d['name']} shape={d['shape']} dtype={d['dtype']}")

print("\n=== OUTPUTS ===")
for i, d in enumerate(output_details):
    print(f"Output {i}: name={d['name']} shape={d['shape']} dtype={d['dtype']}")

In [None]:
# Grab the first validation batch and reorder the image axes (0, 2, 3, 1), i.e.
# from the assumed NCHW PyTorch layout to the NHWC layout the TFLite model takes.
image, mask = next(iter(valid_dataset))
print(image.shape)
image = np.transpose(image.numpy(), (0, 2, 3, 1))

In [None]:
# Run the TFLite model on the NHWC batch. np.ascontiguousarray casts to the dtype
# the interpreter declares for its input and makes the transposed (strided) array
# C-contiguous, so set_tensor does not hit a dtype/layout mismatch.
interpreter.set_tensor(
    input_details[0]['index'],
    np.ascontiguousarray(image, dtype=input_details[0]['dtype']),
)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])

In [None]:
# Side-by-side view of the input image and the binarised TFLite prediction.
PRED_THRESHOLD = 0.5  # applied to the sigmoid output (config "Activation function")

fig, axs = plt.subplots(1, 2, figsize=(6, 4))
axs[0].imshow(image[0])
axs[0].set_title('Image')
axs[0].axis('off')
# Output assumed (batch, channel, H, W): only the input was made channel-last.
axs[1].imshow(np.where(output_data[0, 0] > PRED_THRESHOLD, 1, 0))
axs[1].set_title("TFLite model's inference")
axs[1].axis('off')
# Render the figure and avoid echoing the axis-limits tuple from axis('off').
plt.show()