In [None]:
import numpy as np
from tqdm import tqdm
from time import time

import torchvision
from torchvision import models, transforms

import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

In [None]:
def accuracy(yhat,y):
    """Return the fraction of correct predictions as a 0-dim double tensor.

    The predicted class is the argmax of ``yhat`` along dimension 1.
    ``y`` may hold class indices (1-D, or 2-D with a single column) or
    one-hot rows (2-D with several columns).
    """
    targets_are_indices = y.dim() == 1 or y.size(1) == 1
    predicted = torch.argmax(yhat, 1)
    if targets_are_indices:
        # Compare as columns so both (N,) and (N, 1) targets line up.
        hits = predicted.view(y.size(0), -1) == y.view(-1, 1)
    else:
        # One-hot targets: recover the class index before comparing.
        hits = predicted.view(-1) == torch.argmax(y, 1).view(-1)
    return hits.double().mean()

def train(model,epochs,train_loader,test_loader,feature_extract=False):
    """Train ``model`` with Adam and cross-entropy, logging to TensorBoard.

    Relies on the notebook-level globals ``device`` and ``TB_PATH`` and on
    ``model.name``, which selects the TensorBoard sub-directory.

    Args:
        model: module whose forward pass returns class scores accepted by
            ``nn.CrossEntropyLoss``.
        epochs: number of passes over ``train_loader``.
        train_loader: iterable of ``(x, y)`` training batches.
        test_loader: iterable of ``(x, y)`` batches evaluated after every epoch.
        feature_extract: when True, only parameters with ``requires_grad``
            set are handed to the optimizer; otherwise all parameters are.

    Returns:
        None. Per-epoch averages are written under ``loss/*`` and ``accuracy/*``.
    """
    model = model.to(device)
    writer = SummaryWriter(f"{TB_PATH}/{model.name}")

    print("params to learn:")
    trainable = [(name, param) for name, param in model.named_parameters()
                 if param.requires_grad]
    for name, _ in trainable:
        print("\t", name)
    if feature_extract:
        params_to_update = [param for _, param in trainable]
    else:
        params_to_update = model.parameters()
    optim = torch.optim.Adam(params_to_update,lr=1e-3)

    print(f"running {model.name}")
    loss = nn.CrossEntropyLoss()
    try:
        for epoch in tqdm(range(epochs)):
            # ---- training pass ----
            cumloss, cumacc, count = 0.0, 0.0, 0
            model.train()
            for x,y in train_loader:
                optim.zero_grad()
                x,y = x.to(device), y.to(device)
                yhat = model(x)
                l = loss(yhat,y)
                l.backward()
                optim.step()
                # .item() detaches the scalar: summing the loss tensor itself
                # would keep autograd history alive for the whole epoch.
                cumloss += l.item()*len(x)
                cumacc += accuracy(yhat,y).item()*len(x)
                count += len(x)
            if count:  # skip logging for an empty loader (avoids 0/0)
                writer.add_scalar('loss/train',cumloss/count,epoch)
                writer.add_scalar('accuracy/train',cumacc/count,epoch)

            # ---- evaluation pass (every epoch) ----
            model.eval()
            with torch.no_grad():
                cumloss, cumacc, count = 0.0, 0.0, 0
                for x,y in test_loader:
                    x,y = x.to(device), y.to(device)
                    yhat = model(x)
                    cumloss += loss(yhat,y).item()*len(x)
                    cumacc += accuracy(yhat,y).item()*len(x)
                    count += len(x)
            if count:
                writer.add_scalar('loss/test',cumloss/count,epoch)
                writer.add_scalar('accuracy/test',cumacc/count,epoch)
    finally:
        # Flush pending events and release the log file even on interruption.
        writer.close()

def set_parameter_requires_grad(model, feature_extract):
    """Freeze every parameter of ``model`` when ``feature_extract`` is truthy.

    Does nothing when ``feature_extract`` is falsy; parameters are never
    re-enabled by this function.
    """
    if not feature_extract:
        return
    for _, param in model.named_parameters():
        param.requires_grad = False
                
def get_test_data(dataloader, size):
    """Collect the first ``size`` samples yielded by ``dataloader``.

    The loader is iterated once and iteration stops as soon as enough
    samples have been gathered, so no batch is read twice and the rest of
    the dataset is never loaded.

    Args:
        dataloader: iterable of ``(X, Y)`` tensor batches.
        size: number of samples wanted; negative values are treated as 0.

    Returns:
        ``(X_test, Y_test)`` concatenated along dimension 0, holding
        ``size`` samples, or fewer if the loader runs out first.

    Raises:
        ValueError: if the loader yields no batch at all.
    """
    size = max(int(size), 0)
    x_batches, y_batches = [], []
    collected = 0
    for x_batch, y_batch in dataloader:
        x_batches.append(x_batch)
        y_batches.append(y_batch)
        collected += len(x_batch)
        if collected >= size:
            break
    if not x_batches:
        raise ValueError("dataloader yielded no batches")
    # Concatenate once (instead of once per batch), then trim the overshoot
    # contributed by the last batch.
    X_test = torch.cat(x_batches, 0)[:size]
    Y_test = torch.cat(y_batches, 0)[:size]
    return X_test, Y_test

In [None]:
# Notebook-wide settings: TensorBoard log root and compute device
# (first CUDA GPU when available, CPU otherwise).
TB_PATH = "/tmp/logs/sceance2"
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Pretrained FCN segmentation model with a ResNet-50 backbone (torchvision).
FCNResNet = models.segmentation.fcn_resnet50(pretrained=True)

# Disabled experiment, kept for reference:
# FCNResNet.backbone["conv1"] = nn.Linear(2048, 1024)

# Switch to inference mode, then print the architecture.
FCNResNet.eval()
print(FCNResNet)

# Freeze all weights of the pretrained model.
set_parameter_requires_grad(FCNResNet, feature_extract=True)

In [None]:
# Side length of the square images fed to the model, and loader batch size.
input_size = 224
batch_size = 128

# Per-channel statistics handed to transforms.Normalize below.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Training pipeline. No grayscale conversion this time, because a large
# pretrained model is used.
transformFCNResNetTrain = transforms.Compose(
    [
        # Random crop resized to input_size (train-time data augmentation).
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

# Evaluation pipeline: deterministic resize followed by a central crop.
transformFCNResNetTest = transforms.Compose(
    [
        transforms.Resize(input_size),
        # Keep the central input_size x input_size region.
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

# CIFAR-10 splits, downloaded to ./data when missing.
FCNResNet_trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transformFCNResNetTrain,
)
FCNResNet_testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transformFCNResNetTest,
)

# NOTE(review): the test loader is shuffled too, so any subset drawn from it
# changes from run to run - confirm this is intended.
FCNResNet_trainloader = torch.utils.data.DataLoader(
    FCNResNet_trainset,
    batch_size=batch_size,
    pin_memory=True,
    shuffle=True,
)
FCNResNet_testloader = torch.utils.data.DataLoader(
    FCNResNet_testset,
    batch_size=batch_size,
    pin_memory=True,
    shuffle=True,
)

In [None]:
## Training the network
# FCNResNet.name = "FCNResNet"
# train(FCNResNet, 1, FCNResNet_trainloader, FCNResNet_testloader)

In [None]:
## Accuracy
# Pull a subset of the test loader (requested size: 1000) and move both
# tensors to the compute device.
X_test, Y_test = (
    batch.to(device) for batch in get_test_data(FCNResNet_testloader, 1000)
)
# print("Acc for FCNResNet transfer learning :", accuracy(FCNResNet(X_test), Y_test))

In [None]:
import cv2

# RGB colour assigned to each class index when rendering a segmentation map.
# The list position is the class index; the trailing comment names the class.
label_map = [
    (0, 0, 0),        #  0  background
    (128, 0, 0),      #  1  aeroplane
    (0, 128, 0),      #  2  bicycle
    (128, 128, 0),    #  3  bird
    (0, 0, 128),      #  4  boat
    (128, 0, 128),    #  5  bottle
    (0, 128, 128),    #  6  bus
    (128, 128, 128),  #  7  car
    (64, 0, 0),       #  8  cat
    (192, 0, 0),      #  9  chair
    (64, 128, 0),     # 10  cow
    (192, 128, 0),    # 11  dining table
    (64, 0, 128),     # 12  dog
    (192, 0, 128),    # 13  horse
    (64, 128, 128),   # 14  motorbike
    (192, 128, 128),  # 15  person
    (0, 64, 0),       # 16  potted plant
    (128, 64, 0),     # 17  sheep
    (0, 192, 0),      # 18  sofa
    (128, 192, 0),    # 19  train
    (0, 64, 128),     # 20  tv/monitor
]

def image_overlay(image, segmented_image):
    """Blend a colour segmentation map on top of an image.

    Args:
        image: channel-first ``(3, H, W)`` RGB array. ``uint8`` data is used
            as is. Other dtypes are read as floats: values within ``[0, 1]``
            are scaled by 255, anything else (for example a mean/std
            normalised tensor) is min-max stretched first.
        segmented_image: ``(h, w, 3)`` RGB ``uint8`` colour map; it is
            resized with nearest-neighbour interpolation when its height or
            width differs from the image.

    Returns:
        ``(H, W, 3)`` ``uint8`` array in BGR channel order (OpenCV
        convention) equal to ``alpha * image + beta * map + gamma``,
        saturated to ``[0, 255]``.
    """
    alpha = 1 # weight of the original image
    beta = 0.8 # weight of the segmentation map
    gamma = 0 # scalar added to each sum

    # Channel-first -> channel-last, the layout OpenCV works with.
    image = np.transpose(np.asarray(image), (1, 2, 0))
    if image.dtype != np.uint8:
        # addWeighted needs both inputs in the same depth, so bring the
        # image to uint8 like the colour map.
        image = image.astype(np.float32)
        lo, hi = float(image.min()), float(image.max())
        if lo < 0.0 or hi > 1.0:
            image = (image - lo) / (hi - lo) if hi > lo else np.zeros_like(image)
        image = np.round(image * 255).astype(np.uint8)
    image = cv2.cvtColor(np.ascontiguousarray(image), cv2.COLOR_RGB2BGR)

    segmented_image = cv2.cvtColor(segmented_image, cv2.COLOR_RGB2BGR)
    if segmented_image.shape[:2] != image.shape[:2]:
        # Nearest-neighbour keeps the palette colours unmixed.
        segmented_image = cv2.resize(
            segmented_image,
            (image.shape[1], image.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )
    return cv2.addWeighted(image, alpha, segmented_image, beta, gamma)

def draw_segmentation_map(outputs, palette=None):
    """Turn per-class scores into an RGB colour map.

    Args:
        outputs: score tensor for a single image; after ``squeeze()`` the
            class dimension must come first (for example ``(1, C, H, W)``).
        palette: optional sequence of ``(r, g, b)`` colours indexed by class;
            defaults to the module-level ``label_map``.

    Returns:
        ``uint8`` array with one RGB triple per pixel (``labels.shape + (3,)``).
        Classes without a palette entry are left black.
    """
    if palette is None:
        palette = label_map
    labels = torch.argmax(outputs.squeeze(), dim=0).detach().cpu().numpy()

    # Convert the palette once instead of on every loop iteration.
    colours = np.asarray(palette).astype(np.uint8).reshape(-1, 3)

    segmentation_map = np.zeros(labels.shape + (3,), dtype=np.uint8)
    # Only look up labels the palette covers; the others keep the zero fill.
    known = labels < len(colours)
    segmentation_map[known] = colours[labels[known]]
    return segmentation_map

def get_segment_labels(image, model, device):
    """Run ``model`` on a single image and return its raw output.

    Args:
        image: channel-first ``(3, H, W)`` RGB array with values expected in
            ``[0, 1]``; values outside that range are clipped.
        model: network called on a normalised ``(1, 3, 224, 224)`` tensor.
        device: device the input tensor is moved to before the forward pass.

    Returns:
        Whatever ``model`` returns for the one-image batch.
    """
    # Local import so the function does not depend on the order in which the
    # notebook cells were executed.
    from PIL import Image

    # Channel-first -> channel-last, as expected by PIL.
    image = np.transpose(image, (1, 2, 0))
    # Clip before the uint8 cast: out-of-range values would otherwise wrap
    # around and produce corrupted pixels.
    image = np.clip(image, 0.0, 1.0)
    image = Image.fromarray(np.uint8(image * 255))
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])])
    image = transform(image).to(device)
    image = image.unsqueeze(0) # add a batch dimension
    # Pure inference: no autograd bookkeeping needed.
    with torch.no_grad():
        outputs = model(image)
    return outputs

In [None]:
from PIL import Image
from matplotlib import pyplot as plt
from torch import Tensor

# Segment one test image with the frozen FCN and display the overlay.
X_test, Y_test = X_test.to(device), Y_test.to(device)
FCNResNet = FCNResNet.to(device)

X_test_np = np.array(X_test.cpu())

# The test loader already applied Normalize(mean, std). Undo it so that
# get_segment_labels (which scales by 255 and normalises again) receives
# plain [0, 1] pixel values. float32 is kept because OpenCV colour
# conversion does not accept float64 input.
channel_mean = np.asarray(mean, dtype=np.float32).reshape(3, 1, 1)
channel_std = np.asarray(std, dtype=np.float32).reshape(3, 1, 1)
sample_rgb = np.clip(X_test_np[0] * channel_std + channel_mean, 0.0, 1.0)

outputs = get_segment_labels(sample_rgb, FCNResNet, device)
# "out" is the main segmentation head; "aux" only carries the auxiliary
# training loss branch.
seg = draw_segmentation_map(outputs["out"])

img = image_overlay(sample_rgb, seg)
print(img.shape)

# image_overlay returns BGR (OpenCV order); matplotlib expects RGB.
plt.imshow(img[..., ::-1])
plt.axis("off")
plt.show()

In [None]:
# Time one forward pass for increasing batch sizes.
print(X_test.shape)
on_cuda = device.type == "cuda"
with torch.no_grad():  # inference only
    for t in (20,40,60,80,100,120):
        batch = X_test[:t]
        if on_cuda:
            # CUDA kernels run asynchronously: wait for pending work so the
            # timer brackets exactly this forward pass.
            torch.cuda.synchronize()
        t0 = time()
        FCNResNet(batch)
        if on_cuda:
            torch.cuda.synchronize()
        elapsed = time() - t0
        fps = len(batch) / elapsed if elapsed > 0 else float("inf")
        print("batch size:", len(batch), " --> seconds:", elapsed, " --> FPS:", fps)

In [None]:
import os
# Save the model weights (state_dict only) under PATH.
PATH = "./"
checkpoint_file = os.path.join(PATH, "fcnresnet.pth")
torch.save(FCNResNet.state_dict(), checkpoint_file)