# ライブラリーインストール

In [None]:
!pip install torch==1.13.1
!pip install torchvision==0.14.1
!pip install setuptools>=65.5.1 # not directly required, pinned by Snyk to avoid a vulnerability
!pip install wheel>=0.38.0 # not directly required, pinned by Snyk to avoid a vulnerability

In [None]:
import numpy as np 
import pandas as pd

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
import torchvision
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim

from PIL import Image
import cv2
import albumentations as A

import statistics

import time
import os

from torchsummary import summary
import segmentation_models_pytorch as smp

import glob

from openfl.interface.interactive_api.federation import Federation
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment
from copy import deepcopy

#from tqdm.notebook import tqdm
import tqdm

# Seed the NumPy and PyTorch random number generators for reproducible runs.
np.random.seed(0)
torch.manual_seed(0)

# Run on the current CUDA device when one is available, otherwise on the CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _has_cuda else "cpu")
if _has_cuda:
    print(torch.cuda.get_device_name())

In [None]:
# Path to the pretrained weights file (relative or absolute paths both work).
MODEL_PATH = "drone_Trained.pth"

# Folder that receives the model files written during training.
SAVE_PATH = 'Custum_Model'
# Create it from Python (portable, no shell needed) using the same variable, so the
# folder that is created and the folder that is written to cannot drift apart.
os.makedirs(SAVE_PATH, exist_ok=True)

# データセット

In [None]:
# RGB colour of a label pixel -> class index.
mapping = {
    (0, 0, 0): 0,
    (150, 143, 9): 1,
}


def convert_rgb_to_value(target, color_map=None, fill_value=0):
    """Convert an RGB label image into a 2-D tensor of class indices.

    Args:
        target: tensor of shape (H, W, C) with the colour of every pixel in the
            last dimension (DroneDataset passes a ``long`` tensor built from the
            mask array).
        color_map: dict mapping an RGB tuple to a class index. Defaults to the
            module-level ``mapping``.
        fill_value: class index given to pixels whose colour is not a key of
            ``color_map``. Defaults to 0, the class of black ``(0, 0, 0)`` pixels.

    Returns:
        ``torch.long`` tensor of shape (H, W) holding one class index per pixel.
    """
    if color_map is None:
        color_map = mapping
    h, w = target.shape[0], target.shape[1]
    # Start from a defined value: torch.empty left pixels with an unmapped colour
    # holding arbitrary memory contents, i.e. random (often out-of-range) labels.
    mask = torch.full((h, w), fill_value, dtype=torch.long, device=target.device)
    for rgb, class_idx in color_map.items():
        color = torch.tensor(rgb, dtype=target.dtype, device=target.device)
        # A pixel matches only when every channel equals the reference colour.
        matched = (target == color).all(dim=-1)
        mask[matched] = class_idx
    return mask

# Standard ImageNet per-channel mean/std, used by DroneDataset's T.Normalize.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Training-time augmentation. 704 x 1056 are both multiples of 32, which the
# depth-5 U-Net encoder needs.
t_train = A.Compose([
    A.Resize(704, 1056, interpolation=cv2.INTER_NEAREST),
    A.HorizontalFlip(),
    A.VerticalFlip(),
    A.GridDistortion(p=0.2),
    A.RandomBrightnessContrast((0, 0.5), (0, 0.5)),
    A.GaussNoise()])

# Validation-time preprocessing: resize only. Validation must be deterministic so
# that scores are comparable between runs and rounds; the previous random
# HorizontalFlip/GridDistortion meant every validation pass scored a different
# randomly-augmented version of the data.
t_val = A.Compose([
    A.Resize(704, 1056, interpolation=cv2.INTER_NEAREST)])

class DroneDataset(Dataset):
    """Segmentation dataset that wraps a shard's (image, RGB mask) pairs.

    Items of the wrapped dataset are passed to albumentations,
    ``Image.fromarray`` and ``torch.from_numpy``, so they are presumably numpy
    arrays — TODO confirm against the shard descriptor.
    """

    def __init__(self, dataset, transform=None):
        """
        Args:
            dataset: indexable sequence of ``(image, mask)`` pairs.
            transform: optional albumentations transform applied jointly to the
                image and the mask.
        """
        self.dataset = dataset
        self.transform = transform
        # Built once here instead of being rebuilt on every __getitem__ call.
        self._to_normalized_tensor = T.Compose([T.ToTensor(), T.Normalize(mean, std)])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, mask)``.

        ``image`` is the ToTensor + Normalize output. ``mask`` is the ``long``
        (H, W) class-index tensor produced by ``convert_rgb_to_value``.
        """
        img, mask = self.dataset[idx]

        if self.transform is not None:
            augmented = self.transform(image=img, mask=mask)
            img, mask = augmented['image'], augmented['mask']

        img = self._to_normalized_tensor(Image.fromarray(img))
        mask = convert_rgb_to_value(torch.from_numpy(mask).long())
        return img, mask

In [None]:
class DroneDatasetInterface(DataInterface):
    """OpenFL data interface serving the local drone segmentation shard.

    Constructor keyword arguments are stored in ``self.kwargs``; ``train_bs`` and
    ``valid_bs`` are read as the training / validation batch sizes.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    @property
    def shard_descriptor(self):
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.

        This method will be called during a collaborator initialization.
        Local shard_descriptor will be set by Envoy.
        Builds the training and validation datasets from the shard.
        """
        self._shard_descriptor = shard_descriptor

        self.train_set = DroneDataset(
            shard_descriptor.get_dataset('train'),
            transform=t_train
        )
        self.valid_set = DroneDataset(
            shard_descriptor.get_dataset('val'),
            transform=t_val
        )

    def get_train_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks with optimizer in contract
        """
        # A fresh generator seeded with 0 on every call, so each call yields the
        # same shuffle order (reproducible, but identical across rounds).
        generator = torch.Generator()
        generator.manual_seed(0)
        return DataLoader(self.train_set, batch_size=self.kwargs['train_bs'], shuffle=True, generator=generator)

    def get_valid_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks without optimizer in contract
        """
        # No shuffling: validation metrics are averaged per batch, so an unseeded
        # random batch composition made validation scores vary between runs.
        return DataLoader(self.valid_set, batch_size=self.kwargs['valid_bs'], shuffle=False)

    def get_train_data_size(self):
        """
        Information for aggregation: number of training samples in this shard.
        """
        return len(self.train_set)

    def get_valid_data_size(self):
        """
        Information for aggregation: number of validation samples in this shard.
        """
        return len(self.valid_set)

In [None]:
# Batch size shared by the training and validation loaders.
batch_size = 4
fed_dataset = DroneDatasetInterface(train_bs=batch_size, valid_bs=batch_size)

# モデル定義

In [None]:
# U-Net with a MobileNetV2 encoder. classes=24 must match the checkpoint in
# MODEL_PATH; the labels built by convert_rgb_to_value only use classes 0 and 1,
# and the remaining output channels exist only so the 24-class checkpoint loads.
# encoder_weights=None: the strict load_state_dict below overwrites every weight,
# so downloading the ImageNet encoder first would be wasted work.
model = smp.Unet(
    'mobilenet_v2',
    encoder_weights=None,
    classes=24,
    activation=None,
    encoder_depth=5,
    decoder_channels=[256, 128, 64, 32, 16])

# map_location='cpu' lets a checkpoint saved on a GPU load on CPU-only machines;
# the train/validate tasks move the model to their device with model.to(device).
model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))

# When not using this pretrained model, define it as below instead. Note that there
# are 2 classes: "car" and "everything else".
#model = smp.Unet('mobilenet_v2', encoder_weights='imagenet', classes=2, activation=None, encoder_depth=5, decoder_channels=[256, 128, 64, 32, 16])

In [None]:
# AdamW hyper-parameters.
max_lr = 1e-3
weight_decay = 1e-4

optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=max_lr,
    weight_decay=weight_decay,
)

In [None]:
# Wrap the PyTorch model and optimizer for OpenFL using its PyTorch adapter plugin.
framework_adapter = (
    'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
)
model_interface = ModelInterface(
    model=model,
    optimizer=optimizer,
    framework_plugin=framework_adapter,
)

# Keep an independent copy of the model as it is before training starts.
initial_model = deepcopy(model)

# 学習と検証

In [None]:
def pixel_accuracy(output, mask):
    """Return the fraction of pixels whose predicted class equals ``mask``.

    Args:
        output: logits with the class dimension at dim 1 (e.g. (N, C, H, W)).
        mask: integer class-index tensor matching ``output`` without dim 1.

    Returns:
        Accuracy as a Python float in [0, 1].
    """
    with torch.no_grad():
        predicted = F.softmax(output, dim=1).argmax(dim=1)
        hits = predicted == mask
        return hits.sum().item() / hits.numel()

def mIoU(pred_mask, mask, smooth=1e-10, n_classes=23):
    """Mean intersection-over-union over class indices ``0 .. n_classes - 1``.

    Classes absent from ``mask`` contribute NaN, which ``np.nanmean`` ignores.

    Args:
        pred_mask: logits with the class dimension at dim 1.
        mask: integer class-index tensor (ground truth).
        smooth: small constant added to both intersection and union.
        n_classes: number of class indices to score.

    NOTE(review): the default n_classes=23 differs from the 24-output model built
    in this notebook; labels here are only 0/1, so it has no effect as long as
    that label set does not change.
    """
    with torch.no_grad():
        predicted = torch.argmax(F.softmax(pred_mask, dim=1), dim=1).contiguous().view(-1)
        target = mask.contiguous().view(-1)

        scores = []
        for cls in range(n_classes):
            in_pred = predicted == cls
            in_target = target == cls
            if not in_target.any():
                # Class not present in the ground truth: leave it out of the mean.
                scores.append(np.nan)
                continue
            inter = (in_pred & in_target).sum().float().item()
            uni = (in_pred | in_target).sum().float().item()
            scores.append((inter + smooth) / (uni + smooth))
        return np.nanmean(scores)

In [None]:
epoch = 1


def cross_entropy(output, target):
    """Multi-class cross-entropy loss (mean over all pixels).

    Thin wrapper around ``F.cross_entropy``: ``output`` are raw logits with the
    class dimension at dim 1, ``target`` holds integer class indices.
    (Not binary cross-entropy: the model emits one logit per class.)
    """
    return F.cross_entropy(input=output, target=target)

In [None]:
# Registry for the federated tasks registered below via @task_interface.register_fl_task.
task_interface = TaskInterface()

# Task interface currently supports only standalone functions.
@task_interface.register_fl_task(model='model', data_loader='train_loader', device='device', optimizer='optimizer')
def train(model, train_loader, optimizer, device):
    """Run one local training pass over ``train_loader``.

    For every batch: cross-entropy loss, backward pass and optimizer step; mIoU
    and pixel accuracy are accumulated from the same forward pass for logging.

    Returns:
        dict with the per-batch means ``train_loss``, ``train_iou``, ``train_acc``.
    """
    torch.cuda.empty_cache()

    model.to(device)
    model.train()

    loss_sum = 0
    iou_sum = 0
    acc_sum = 0

    batches = tqdm.tqdm(train_loader, desc="train")
    start = time.time()

    for images, masks in batches:
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()

        logits = model(images)
        loss = F.cross_entropy(logits, masks)

        # Metrics come from this forward pass, i.e. before the weight update.
        iou_sum += mIoU(logits, masks)
        acc_sum += pixel_accuracy(logits, masks)

        loss.backward()
        optimizer.step()

        loss_sum += loss.item()

    # Average over the number of batches.
    n_batches = len(batches)
    mean_train_loss = loss_sum / n_batches
    mean_train_iou = iou_sum / n_batches
    mean_train_acc = acc_sum / n_batches
    elapsed_min = (time.time() - start) / 60
    # TODO: include the round number once it is available to the task.
    print(
        "Train Loss: {:.3f}..".format(mean_train_loss),
        "Train mIoU:{:.3f}..".format(mean_train_iou),
        "Train Acc:{:.3f}..".format(mean_train_acc),
        "Time: {:.2f}m".format(elapsed_min))

    return {'train_loss': mean_train_loss, 'train_iou': mean_train_iou, 'train_acc': mean_train_acc}
    
# Best validation loss and counters, kept across validate() calls. These used to be
# locals re-initialised on every call, so min_loss was always inf: the "decreasing"
# branch always fired, `decrease` never got past 2, the checkpoint was never
# written and the "not decrease" branch was unreachable.
_val_tracker = {'min_loss': np.inf, 'decrease': 1, 'not_improve': 0}


@task_interface.register_fl_task(model='model', data_loader='val_loader', device='device')
def validate(model, val_loader, device, loss_fn=cross_entropy):
    """Evaluate ``model`` on ``val_loader`` and track the best loss seen so far.

    Computes per-batch mean loss, mIoU and pixel accuracy. Each time the mean loss
    beats the best value in ``_val_tracker`` the ``decrease`` counter (starting
    at 1) is incremented, and whenever it reaches a multiple of 5 the weights are
    saved under SAVE_PATH.

    NOTE(review): the tracker lives in this process's module globals. Whether it
    persists across rounds depends on how OpenFL loads the task on the
    collaborator, and OpenFL may run this task for both the aggregated and the
    locally tuned model, which then share one tracker — confirm.

    Returns:
        dict with ``test_loss``, ``val_iou`` and ``val_acc`` (per-batch means).
    """
    torch.cuda.empty_cache()

    model.to(device)
    model.eval()

    test_loss = 0
    test_accuracy = 0
    val_iou_score = 0

    val_loader = tqdm.tqdm(val_loader, desc="val")

    since = time.time()

    with torch.no_grad():
        for image_tiles, mask_tiles in val_loader:
            image = image_tiles.to(device)
            mask = mask_tiles.to(device)

            output = model(image)

            # evaluation metrics
            val_iou_score += mIoU(output, mask)
            test_accuracy += pixel_accuracy(output, mask)

            # loss
            test_loss += loss_fn(output, mask).item()

    # Means over batches.
    n_batches = len(val_loader)
    mean_test_loss = test_loss / n_batches
    mean_val_iou = val_iou_score / n_batches
    mean_val_acc = test_accuracy / n_batches

    if mean_test_loss < _val_tracker['min_loss']:
        print('Loss Decreasing.. {:.3f} >> {:.3f} '.format(_val_tracker['min_loss'], mean_test_loss))
        _val_tracker['min_loss'] = mean_test_loss
        _val_tracker['decrease'] += 1
        # Count consecutive non-improvements since the latest best loss.
        _val_tracker['not_improve'] = 0
        if _val_tracker['decrease'] % 5 == 0:
            print('saving model...')
            # The folder may not exist on the node running this task.
            os.makedirs(SAVE_PATH, exist_ok=True)
            torch.save(model.state_dict(), SAVE_PATH + '/Unet-Mobilenet_v2_mIoU-{:.3f}.pth'.format(mean_val_iou))
    elif mean_test_loss > _val_tracker['min_loss']:
        # Keep the best loss untouched: non-improvement is measured against it
        # (the old code overwrote it with the worse value).
        _val_tracker['not_improve'] += 1
        print(f"Loss Not Decrease for {_val_tracker['not_improve']} time")

    # TODO: include the round number once it is available to the task.
    print(
        "Val Loss: {:.3f}..".format(mean_test_loss),
        "Val mIoU: {:.3f}..".format(mean_val_iou),
        "Val Acc:{:.3f}..".format(mean_val_acc),
        "Time: {:.2f}m".format((time.time()-since)/60))

    return {'test_loss': mean_test_loss, 'val_iou': mean_val_iou, 'val_acc': mean_val_acc}

# 連合への接続

In [None]:
client_id = 'api'
director_node_fqdn = 'localhost'
director_port = '50051'

# 1) Connect without TLS (for testing / proof-of-concept use only).
federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port=director_port,
    tls=False,
)

In [None]:
# create an experimnet in federation
# Create an experiment in the federation.
# NOTE(review): 'done_' looks like a typo for 'drone_'; left unchanged because
# other tooling may refer to the experiment by this exact name.
experiment_name = 'done_segmentation_experiment'
fl_experiment = FLExperiment(
    federation=federation,
    experiment_name=experiment_name,
)

In [None]:
# print the default federated learning plan
# Print the default federated learning plan.
import openfl.native as fx

default_plan = fx.get_plan(fl_plan=fl_experiment.plan)
print(default_plan)

In [None]:
# The following command zips the workspace and python requirements to be transfered to collaborator nodes
# Zip the workspace and Python requirements for transfer to the collaborator
# nodes, then start the experiment.
rounds_to_train = 1
config_overrides = {'network.settings.agg_port': 50002}

fl_experiment.start(
    model_provider=model_interface,
    task_keeper=task_interface,
    data_loader=fed_dataset,
    rounds_to_train=rounds_to_train,
    opt_treatment='CONTINUE_GLOBAL',
    override_config=config_overrides,
)

In [None]:
# If user want to stop IPython session, then reconnect and check how experiment is going
# fl_experiment.restore_experiment_state(model_interface)

fl_experiment.stream_metrics(tensorboard_logs=False)

In [None]:
# Fetch the model returned by get_best_model() and save its weights to disk.
best_model = fl_experiment.get_best_model()
best_state_dict = best_model.state_dict()
torch.save(best_state_dict, 'best_model.pth')

In [None]:
# Presumably deletes the data OpenFL stored for this experiment — confirm the
# results above are saved before running this cell.
fl_experiment.remove_experiment_data()