In [1]:
import dagshub
from xrkit.base import CONFIG
import torch
import os

import pytorch_lightning as L
from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from torch.utils.data import DataLoader
from xrkit.data import SegmentationDataset
from xrkit.models import *

# Move the working directory one level up so the relative paths used later in
# the notebook (e.g. "models/...") resolve from there (presumably the project
# root — confirm against where the notebook lives). A bare os.chdir("..")
# climbs one more level every time this cell is re-run, so the move is guarded
# by a marker that exists only for the lifetime of the kernel.
if "PROJECT_ROOT" not in globals():
    os.chdir("..")
    PROJECT_ROOT = os.getcwd()

# Select the "high" float32 matmul precision setting in PyTorch.
torch.set_float32_matmul_precision("high")

# Initialise DagsHub for the repository named in CONFIG, with MLflow enabled.
dagshub.init(CONFIG.dagshub.repository_name, CONFIG.dagshub.repository_owner, mlflow=True)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def make_loader(dataset, shuffle=False):
    """Wrap ``dataset`` in a DataLoader configured from the shared CONFIG values.

    Args:
        dataset: Dataset to iterate over (here a SegmentationDataset split).
        shuffle: Whether the loader reshuffles the samples at every epoch.

    Returns:
        A DataLoader using CONFIG.base.batch_size and CONFIG.base.n_workers,
        with pinned memory and without dropping the last incomplete batch.
    """
    return DataLoader(
        dataset,
        batch_size=CONFIG.base.batch_size,
        shuffle=shuffle,
        num_workers=CONFIG.base.n_workers,
        pin_memory=True,
        drop_last=False,
    )


# Only the training split is shuffled, so each epoch sees the samples in a new
# order; validation and test keep a fixed order so their metrics are comparable
# between runs.
train_dataset = SegmentationDataset("train")
train_loader = make_loader(train_dataset, shuffle=True)

validation_dataset = SegmentationDataset("validation")
validation_loader = make_loader(validation_dataset)

test_dataset = SegmentationDataset("test")
test_loader = make_loader(test_dataset)

In [3]:
from PIL import Image

# Sanity check: push one training batch through a freshly built DenseNet201
# U-Net network and view the first element of its output as grayscale images.
inputs, targets = next(iter(train_loader))

# No gradients are needed for a visual check, so skip building the autograd graph.
with torch.no_grad():
    outputs = DenseNet201UNetModel(n_epochs=1).network(inputs)

first_output = outputs[0]
lowest, highest = first_output.min(), first_output.max()

# Min-max scale to [0, 1]. Clamping the range avoids 0/0 (NaN) when the output
# is constant; in that case every pixel maps to 0.
value_range = (highest - lowest).clamp_min(torch.finfo(first_output.dtype).eps)
tensor_normalized = (first_output - lowest) / value_range

# Cast to uint8 so PIL builds ordinary 8-bit grayscale images instead of
# 32-bit float ones.
tensor_image = (tensor_normalized * 255).round().to(torch.uint8)

# One image per slice along the first dimension of `first_output`.
# Assumes each slice is a 2-D (height, width) map — TODO confirm against the
# network's output layout.
image_stack = [Image.fromarray(plane.cpu().numpy()) for plane in tensor_image]

for img in image_stack:
    img.show()

In [None]:
from torchvision import models

# Instantiate torchvision's DenseNet-201; as the last expression of the cell,
# its module structure is shown as the cell output.
models.densenet201()

In [None]:
epochs = 100

# Architecture under training; uncomment one of the alternatives to switch.
# model = DenseNet201UNetModel(n_epochs=epochs)
# model = NASNetLargeUNetModel(n_epochs=epochs)
# model = ResNet152V2UNetModel(n_epochs=epochs)
model = VGG19UNetModel(n_epochs=epochs)

# Lower-cased class name with its last five characters (the "model" suffix)
# removed; used as the MLflow experiment name and the checkpoint sub-directory.
experiment_name = type(model).__name__.lower()[:-5]

# Quantity watched by both callbacks, and the direction that counts as better.
metric = "validation_loss"
mode = "min"

logger = MLFlowLogger(
    tracking_uri=CONFIG.dagshub.tracking_uri,
    experiment_name=experiment_name,
)

# Keep only the single best checkpoint according to the monitored metric.
checkpoint_callback = ModelCheckpoint(
    dirpath=f"models/{experiment_name}",
    filename="model-{epoch:03d}-{validation_loss:.2f}",
    monitor=metric,
    mode=mode,
    save_top_k=1,
    enable_version_counter=False,
)

# Stop after 10 consecutive validation checks without any improvement.
early_stop_callback = EarlyStopping(
    monitor=metric,
    mode=mode,
    patience=10,
    min_delta=0.00,
)

trainer = L.Trainer(
    max_epochs=epochs,
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback],
)
trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=validation_loader,
)

In [None]:
# Evaluate the in-memory `model` on the test loader using the existing trainer.
# NOTE(review): no ckpt_path is passed; if the best saved checkpoint (rather
# than the weights left in memory when training stopped) should be scored,
# confirm and pass it explicitly.
trainer.test(model=model, dataloaders=test_loader)

In [None]:
# Smoke test: run a random batch through InceptionV3 built with
# task="segmentation" and print the shape of what it returns.
# NOTE(review): `input` shadows the Python builtin, and `model` rebinds the
# name that earlier cells use for the trained model; later cells see these
# new values.
n_channels = 1
input = torch.randn((4, n_channels, 128, 128))
model = InceptionV3(task="segmentation", n_inputs=n_channels)
print(model(input).shape)

from torchview import draw_graph
import graphviz

# Graph rendering is currently disabled:
# graphviz.set_jupyter_format("png")

# draw_graph(model, input_size=input.shape, expand_nested=True, depth=3).visual_graph

In [None]:
import timm

# Create timm's "resnet152d" with pretrained weights and num_classes=3.
# NOTE(review): this rebinds `model`, replacing whatever an earlier cell
# stored under that name.
model = timm.create_model("resnet152d", pretrained=True, num_classes=3)

In [None]:
# `nn` is imported here because no earlier cell imports it explicitly; without
# this the cell raises NameError when the notebook is run top to bottom on a
# fresh kernel.
import torch.nn as nn
from torchinfo import summary
from xrkit.base import CONFIG


batch_size = CONFIG.base.batch_size

# Drop the last two child modules of whichever `model` is currently bound and
# summarise what remains for a batch of 3-channel 512x512 inputs.
backbone = nn.Sequential(*list(model.children()))[:-2]
summary(backbone, input_size=(batch_size, 3, 512, 512))

In [None]:
from torchview import draw_graph
import graphviz

# Use PNG as the format graphviz objects are displayed with in Jupyter.
graphviz.set_jupyter_format("png")

# The input size is spelled out rather than read from the `input` variable of
# another cell: when the notebook runs top to bottom that variable is a
# 1-channel 128x128 batch, which does not match torchvision's Inception v3
# (3-channel input, documented at 299x299).
inception_input_size = (4, 3, 299, 299)
draw_graph(models.inception_v3(), input_size=inception_input_size, expand_nested=True, depth=3).visual_graph

In [None]:
# torchvision documents Inception v3 as taking N x 3 x 299 x 299 inputs. A
# freshly constructed model is in training mode, where its auxiliary head also
# runs; a 256x256 input appears to leave that head with a feature map smaller
# than its 5x5 convolution, so the documented size is used instead.
# NOTE(review): `input` shadows the Python builtin; the name is kept because
# other cells in this notebook use it.
input = torch.randn((4, 3, 299, 299))
models.inception_v3()(input)

In [None]:
import timm
import torch.nn as nn

# Random batch of four 3-channel 128x128 inputs (not used further in this cell).
input = torch.randn((4, 3, 128, 128))
# Create timm's "legacy_xception" with pretrained weights and num_classes=10000.
model = timm.create_model("legacy_xception", pretrained=True, num_classes=10000)
# Keep only the last two child modules and display them as the cell output.
# NOTE(review): this rebinds `model` twice; later cells see the sliced module.
model = nn.Sequential(*list(model.children()))[-2:]
model

In [None]:
# Instantiate torchvision's VGG-19 and display its `classifier` attribute.
network = models.vgg19()
network.classifier

In [None]:
from torchview import draw_graph
import graphviz

# Use PNG as the format graphviz objects are displayed with in Jupyter.
graphviz.set_jupyter_format("png")

# Draw the graph of `model.network` for an input of size (4, 1, 256, 256).
# NOTE(review): when the cells run top to bottom, `model` was last rebound to
# an nn.Sequential slice in an earlier cell; confirm it has a `.network`
# attribute here — this presumably targets one of the *UNetModel instances.
draw_graph(model.network, input_size=(4, 1, 256, 256), expand_nested=True, depth=2).visual_graph

In [None]:
# # Resume Training
# trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=validation_loader, ckpt_path='models/unet/model-epoch=000-validation_fbeta_score=0.63.ckpt')

In [None]:
# model = UNetModel.load_from_checkpoint("models/unet/model-epoch=000-validation_fbeta_score=0.00.ckpt")
# print(model)

In [None]:
# checkpoint = torch.load("models/unet/model-epoch=000-validation_fbeta_score=0.00.ckpt")

# # Extract the state dict from the checkpoint
# state_dict = checkpoint["state_dict"]

# # Instantiate the model architecture
# k = UNetModel(n_epochs=epochs)

# # Load the state dict into the model
# k.load_state_dict(state_dict)

In [None]:
# trainer2 = L.Trainer(logger=logger)
# trainer2.test(model=model, dataloaders=test_loader)