In [None]:
%load_ext autoreload
%autoreload 2  

In [None]:
from torchsummary import summary

In [None]:
import matplotlib.pyplot as plt 

In [None]:
from torch.utils.tensorboard import SummaryWriter

In [None]:

import segmentation_models_pytorch as smp
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from quickannotator.dl.inference import run_inference, getPendingInferenceTiles
from quickannotator.dl.dataset import TileDataset
import io
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
from safetensors.torch import save_file

def get_transforms(tile_size):
    """Build the albumentations pipeline used to augment training tiles.

    The pipeline occasionally rescales the input, pads it back up to at least
    ``tile_size`` on both sides, applies random flips and an always-on rotation
    (reflecting at the borders), crops a ``tile_size`` x ``tile_size`` window,
    normalizes with the mean/std commonly used for ImageNet-pretrained
    encoders, and finally converts the result to a torch tensor.

    Args:
        tile_size: Edge length (in pixels) of the square tile produced by the
            final crop; also the minimum height/width enforced by padding.

    Returns:
        An ``A.Compose`` instance wrapping the transforms above.

    TODO: this helper probably belongs in a shared module rather than here.
    """
    # Geometric augmentations; the order matters because padding must happen
    # after the optional rescale and before the final crop.
    geometric_steps = [
        A.RandomScale(scale_limit=-.1, p=.1),
        A.PadIfNeeded(min_height=tile_size, min_width=tile_size),
        A.VerticalFlip(p=.5),
        A.HorizontalFlip(p=.5),
        # Augmentations tried earlier and currently disabled:
        #A.Blur(p=.5),
        # # Downscale(p=.25, scale_min=0.64, scale_max=0.99),
        #A.GaussNoise(p=.5, var_limit=(10.0, 50.0)),
        # A.GridDistortion(p=.5, num_steps=5, distort_limit=(-0.3, 0.3),
        #                 border_mode=cv2.BORDER_REFLECT),
        # A.ISONoise(p=.5, intensity=(0.1, 0.5), color_shift=(0.01, 0.05)),
        # A.RandomBrightnessContrast(p=0.5, brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), brightness_by_max=True),
        # A.RandomGamma(p=.5, gamma_limit=(80, 120), eps=1e-07),
        # A.MultiplicativeNoise(p=.5, multiplier=(0.9, 1.1), per_channel=True, elementwise=True),
        # A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=10, val_shift_limit=10, p=.9),
        A.Rotate(p=1, border_mode=cv2.BORDER_REFLECT),
        A.RandomCrop(tile_size, tile_size),
    ]

    # Value scaling and tensor conversion always run last.
    output_steps = [
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]

    return A.Compose(geometric_steps + output_steps)

In [None]:
# Run configuration. TODO: these should eventually come from project settings.
classid = 2            # class id forwarded to TileDataset and used in log/checkpoint paths
tile_size = 2_048      # square tile edge passed to get_transforms() and summary()

boost_count = 5        # forwarded to TileDataset
batch_size_train = 4
batch_size_infer = 1   # not used in this cell
edge_weight = 1_000    # forwarded to TileDataset; also the base of `edge_weight ** weights` in the training loop
num_workers = 0

num_iters = 10         # number of outer training iterations

# Resolve the device once, up front, so the model and the loss are placed
# consistently instead of mixing a hard-coded .cuda() with a CPU fallback.
device = 'cuda' if torch.cuda.is_available() else "cpu"

dataset = TileDataset(classid,
                      edge_weight=edge_weight, transforms=get_transforms(tile_size),
                      boost_count=boost_count)

dataloader = DataLoader(dataset, batch_size=batch_size_train, shuffle=False, num_workers=num_workers) #NOTE: for dataset of type iter - shuffle must == False

model = smp.Unet(encoder_name="efficientnet-b0", encoder_weights="imagenet",
                 decoder_channels=(64, 64, 64, 32, 16), in_channels=3, classes=1)

#model = model.half()

# Move the model before building the optimizer so the optimizer is created
# over the parameters as they live on the target device.
model.to(device)

# reduction='none' keeps the per-pixel loss; the training loop weights and
# reduces it itself.
criterion = nn.BCEWithLogitsLoss(reduction='none').to(device)
optimizer = optim.AdamW(model.parameters(), lr=.01, weight_decay=1e-2) #TODO: this should be a setting

model.train()



In [None]:
# Freeze the encoder: none of its parameters will receive gradients.
model.encoder.requires_grad_(False)

# List every parameter of the whole model that is now excluded from training.
frozen_param_names = [
    pname for pname, tensor in model.named_parameters() if not tensor.requires_grad
]
for pname in frozen_param_names:
    print(f"{pname} is frozen")

In [None]:
# Layer-by-layer summary of the model for one 3-channel tile_size x tile_size input.
# TODO: log this.
# fp16 variant kept for reference:
#   _ = summary(model, (3, tile_size, tile_size), dtypes=[torch.float16])
summary_input_size = (3, tile_size, tile_size)
_ = summary(model, summary_input_size)  # bind the return value so the cell does not echo it

In [None]:
import datetime

# Total loss of each outer iteration since the last periodic checkpoint.
running_loss = []
# One TensorBoard run directory per launch: /tmp/<classid>/<timestamp>.
writer = SummaryWriter(log_dir=f"/tmp/{classid}/{datetime.datetime.now().strftime('%b%d_%H-%M-%S')}")

# Gradient scaling for the fp16 autocast region below.
# NOTE(review): this cell hard-codes "cuda" for both the scaler and autocast.
scaler = torch.amp.GradScaler("cuda")

#TODO: needs to go somewhere reasonable maybe /projid/models/classid/ ? or something
checkpoint_path = f"/tmp/model_{classid}.safetensors"

for niter in tqdm(range(num_iters)): #TODO: this should be a setting
    # A fresh iterator is created on every outer iteration.
    # NOTE(review): for an iterable-style dataset this restarts iteration each
    # time — confirm that is intended.
    images, masks, weights = next(iter(dataloader))
    images = images.half().to(device)
    masks = masks.to(device)
    weights = weights.to(device)
    epsilon = 1e-6  # currently unused in this cell

    # Loop-invariant for the inner steps: the same batch is reused 5 times.
    masks_float = masks.float()
    positive_mask = (masks == 1).float()
    unlabeled_mask = (masks == 0).float()

    for _ in tqdm(range(5),leave=False):

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            outputs = model(images)

            # Per-pixel BCE (criterion uses reduction='none'), scaled by
            # edge_weight ** weights. It must stay per-pixel here: reducing to
            # a scalar before applying the masks below would make the
            # positive/unlabeled split a mere rescaling of one number.
            pixel_loss = criterion(outputs, masks_float)
            pixel_loss = pixel_loss * (edge_weight ** weights).type_as(pixel_loss)

            # Overall weighted mean, logged for reference only.
            loss = pixel_loss.mean()

            # Positive pixels count fully, unlabeled pixels at 10%.
            positive_loss = 1.0 * (pixel_loss * positive_mask).mean()
            unlabeled_loss = .1 * (pixel_loss * unlabeled_mask).mean()

            loss_total = positive_loss + unlabeled_loss

        #loss_total.backward()
        scaler.scale(loss_total).backward()

        #optimizer.step()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    # Only the values from the last inner step are recorded per outer iteration.
    running_loss.append(loss_total.item())

    writer.add_scalar('loss/loss', loss.item(), niter)
    writer.add_scalar('loss/positive_loss', positive_loss.item(), niter)
    writer.add_scalar('loss/unlabeled_loss', unlabeled_loss.item(), niter)
    writer.add_scalar('loss/loss_total', loss_total.item(), niter)

    #print ("losses:\t",loss_total,positive_mask.sum(),positive_loss,unlabeled_loss)

    if niter % 50==0:
        #print (f"niter [{niter}], Loss: {sum(running_loss)/len(running_loss)}")
        running_loss=[]

        # TODO: do we always overwrite the last saved model, or only save when
        # some loss threshold is met? Another option is to do both: save on a
        # regular basis (so a crash can be recovered from the nearest
        # checkpoint) and also let the user pick a checkpoint from a dropdown
        # in the front end, similar to QAv1. That is a more advanced feature
        # and not very "apple like", since users would need to be told
        # when/why/how to use the different models. Suggest avoiding it for
        # now and just saving the last one.
        save_file(model.state_dict(), checkpoint_path)
        last_save = 0  # currently unused in this cell

# The periodic save above only fires on multiples of 50, so without this the
# weights from the final iterations would never reach disk.
if num_iters > 0 and (num_iters - 1) % 50 != 0:
    save_file(model.state_dict(), checkpoint_path)

# Make sure all logged scalars are written out to the event file.
writer.flush()



In [None]:
# Shape of `images` as currently bound; right after the training cell this is
# the last training batch.
images.shape

In [None]:
# Same dataset settings as the training dataset, but with transforms=None so
# samples come back without the augmentation pipeline.
tds = TileDataset(
    classid,
    edge_weight=edge_weight,
    transforms=None,
    boost_count=boost_count,
)


In [None]:
# Pull a single (image, mask, weight) sample from the un-augmented dataset and
# preview it. This rebinds the global `images` / `masks` / `weights` names that
# the training cell also uses.
images, masks, weights= next(iter(tds))
plt.imshow(images)
plt.show()
# squeeze() drops singleton dimensions so the mask can be shown as a 2-D image.
plt.imshow(masks.squeeze())


In [None]:
# Write the raw sample to the current working directory for offline inspection.
# The mask is scaled by 255 so that a value of 1 becomes white.
# NOTE(review): cv2.imwrite interprets 3-channel data as BGR; if TileDataset
# yields RGB (it is displayed with plt.imshow elsewhere) the saved colors will be
# swapped — confirm. Also confirm the mask dtype is one cv2.imwrite accepts.
cv2.imwrite("mask2.png",masks.squeeze()*255)
cv2.imwrite("img2.png",images)

In [None]:
# Shape of `images` as currently bound; after the raw-sample cell this is the
# single un-augmented image.
images.shape

In [None]:
# Shape of `masks` as currently bound; after the raw-sample cell this is the
# single un-augmented mask.
masks.shape

In [None]:
# Draw one batch through the augmenting DataLoader; this rebinds the global
# `images` / `masks` / `weights`.
images, masks, weights = next(iter(dataloader))

# First sample of the batch, detached and on the CPU for plotting.
i = images.detach().cpu()[0, ::]
m = masks.detach().cpu()[0, ::]

# matplotlib wants the channel axis last, hence the permute.
plt.imshow(i.squeeze().permute(1, 2, 0))
plt.show()
plt.imshow(m.squeeze())

# Largest mask value in the whole batch.
print(masks.max())


In [None]:
# Largest value in the current `masks` batch, shown as the cell output.
masks.max()

In [None]:
# Show the mask of the first sample in the current batch.
plt.imshow(masks[0,::].cpu().detach().squeeze())  

In [None]:
# Forward pass for visual inspection only. Run it in eval mode and without
# autograd so it neither updates normalization running statistics nor builds
# a graph for a batch that is never back-propagated.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        outputs = model(images.to(device))
finally:
    # Put the model back in whatever mode it was in, so re-running the
    # training cell afterwards behaves as before.
    model.train(was_training)

In [None]:
# Turn the raw logits into per-pixel probabilities on the CPU.
o = outputs.detach().sigmoid().cpu()

In [None]:
# Binary prediction for the first sample: probabilities thresholded at 0.5.
plt.imshow(o[0,::].squeeze()>.5)

In [None]:
# Histogram of all predicted probabilities in the batch.
plt.hist(o.squeeze().numpy().flatten())

In [None]:
# First mask of the batch scaled by 255, echoed as the cell output.
masks[0,::].cpu().detach().squeeze()*255

In [None]:
# Dump the inspected sample, its ground-truth mask and the thresholded
# prediction as PNGs in the current working directory.

# NOTE(review): `i` is the normalized image, so it contains negative values and
# the *128 scaling is only a rough visualization; cv2.imwrite also treats the
# channels as BGR — confirm this output is good enough for its purpose.
cv2.imwrite("gt_io.png",i.squeeze().permute(1,2,0).numpy()*128)

# Cast to uint8 explicitly: multiplying a boolean array by 255 produces a
# platform integer array (int64 on Linux), which cv2.imwrite cannot encode.
# The values are 0 or 255, so the cast is lossless.
cv2.imwrite("gt.png",(masks[0,::].cpu().detach().numpy().squeeze()*255).astype("uint8"))
cv2.imwrite("out.png", ((o[0,::].numpy().squeeze()>.5)*255).astype("uint8"))

In [None]:
# import ttach as tta
# tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.five_crop_transform(crop_height=2_016,crop_width=2_016), merge_mode='mean')
# outputs = tta_model(images.to(device))
# o=torch.sigmoid(outputs.detach()).cpu()
# plt.imshow(o[0,::].squeeze()>.5)

In [None]:
# Shape of `masks` as currently bound (the most recently loaded batch).
masks.shape