# Environment setup

In [None]:
# Authenticate the current Colab user via google.colab.auth.
# NOTE(review): presumably required so the gcloud/gsutil/gcsfuse commands in
# the following cells can access the project bucket — confirm.
from google.colab import auth
auth.authenticate_user()

In [None]:
# Select the Google Cloud project used by subsequent gcloud/gsutil commands.
project_id = 'stairnet'
!gcloud config set project {project_id}

# Sanity check: list the bucket contents to confirm the dataset location is
# correct and reachable with the current credentials.
! gsutil ls -al gs://stairnet_bucket/

In [None]:
!echo "deb http://packages.cloud.google.com/apt gcsfuse-bionic main" > /etc/apt/sources.list.d/gcsfuse.list
!curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add -
!apt -qq update
!apt -qq install gcsfuse

!mkdir data
!gcsfuse --implicit-dirs stairnet_bucket data

In [None]:
import os
# Use .get() so that a missing variable triggers the assertion (with its
# actionable message) instead of an opaque KeyError.
assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

In [None]:
# Install the TPU client, torch 1.11.0 and the matching torch_xla 1.11 wheel
# (the wheel filename targets CPython 3.7).
!pip install cloud-tpu-client==0.10 torch==1.11.0 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-1.11-cp37-cp37m-linux_x86_64.whl
# NOTE(review): timm and torchvideo are unpinned, so a fresh run may pull a
# different version than the one this notebook was developed against —
# consider pinning once the working versions are confirmed.
!pip install timm
!pip install  torchvideo

# FLAGS

In [None]:
# Run configuration shared with every spawned TPU process.
FLAGS = {
    # Dataset split locations (under the gcsfuse mount point `data/`).
    'train_dataset_path': 'data/StairNet_Seq_5/SplitsVideo_numpy/Train/',
    'val_dataset_path': 'data/StairNet_Seq_5/SplitsVideo_numpy/Val/',
    'test_dataset_path': 'data/StairNet_Seq_5/SplitsVideo_numpy/Test/',
    # Per-core batch size passed to each DataLoader.
    'batch_size': 8,
    # A "DataLoader worker (pid(s) 9180, 9220) exited unexpectedly" error was
    # noted here — presumably observed with a different worker count.
    'num_workers': 1,
    'learning_rate': 0.0001,
    'momentum': 0.5,
    'num_epochs': 10,
    # Number of processes passed to xmp.spawn.
    'num_cores': 8,
    'log_steps': 20,
    'metrics_debug': True,
}

# Imports

In [None]:
import numpy as np
import os
import gc
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
from torch.utils.data.distributed import DistributedSampler
import torch_xla.utils.utils as xu
from torchvision import datasets, transforms
import timm
import torchvideo
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score

#os.environ['XLA_USE_BF16']="1"
#os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '100000000'

# Model

In [None]:
class ClassificationModel(nn.Module):
    """Frozen image encoder + bidirectional LSTM sequence classifier.

    Each frame of the input sequence is encoded by a pretrained timm model,
    projected to ``hidden_size``, run through a bidirectional LSTM, and the
    final forward/backward hidden states are classified by a linear layer.

    Args:
        n_classes: number of output classes.
        encoder_name: timm model name; the model must expose ``conv_head``
            and ``forward_features``.
        hidden_size: size of the per-frame projection fed to the LSTM.
        rnn_hidden: LSTM hidden size per direction.
    """

    def __init__(self, n_classes, encoder_name = 'mobilenetv3_large_100', hidden_size= 256, rnn_hidden=128):
        super().__init__()

        self.num_labels = n_classes

        # Pretrained encoder, frozen: eval mode and no gradients.
        self.encoder = timm.create_model(
            encoder_name,
            pretrained=True,
            #features_only=True,
        ).eval().requires_grad_(False)

        self.map_to_seq = nn.Linear(
            self.encoder.conv_head.out_channels, hidden_size
        )
        # batch_first is left at its default (False): input is (seq, batch, feature).
        self.lstm = nn.LSTM(
            hidden_size, rnn_hidden, bidirectional=True
        )
        self.dense = nn.Linear(2 * rnn_hidden, n_classes)

    def train(self, mode=True):
        """Set the training mode, but always keep the frozen encoder in eval mode.

        Without this override, ``model.train()`` would switch the encoder back
        to training mode and undo the ``.eval()`` applied in ``__init__``, so
        any mode-dependent layers inside it would no longer behave as frozen.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    def encode(self, x):
        """Encode every frame of ``x`` and return features stacked as (seq, batch, feature).

        ``x`` is indexed as (batch, seq, d2, d3, channels); the last axis is
        moved to position 1 so that each frame ``x[:, :, i]`` is channels-first.
        """
        x = x.permute(0, 4, 1, 2, 3)
        # assumes forward_features returns (batch, C, 1, 1) so the two squeezes
        # yield (batch, C) — TODO confirm for the installed timm version.
        features = torch.stack([
            self.encoder.forward_features(x[:, :, i]).squeeze(3).squeeze(2) for i in range(x.shape[2])
        ])
        return features

    def forward(self, x):
        """Return class logits of shape (batch, n_classes) for a batch of sequences."""
        # encode input sequence into a feature space
        conv = self.encode(x) # (seq, batch, feature)
        seq = self.map_to_seq(conv)

        recurrent, (hidden, cell) = self.lstm(seq)
        # concat the final forward and backward hidden state
        hidden = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)

        output = self.dense(hidden)
        return output # shape: (batch, num_class)

# Only instantiate model weights once in memory.
# The model is wrapped with xmp.MpModelWrapper and stored in FLAGS so that
# each spawned process can move it to its own device with `.to(device)`.
FLAGS['model'] = xmp.MpModelWrapper(ClassificationModel(n_classes=4))

# Dataset

In [None]:
# Mapping from the class tag at the end of each filename to its class index.
CLASS_NAMES  = {'IS': 0, 'ISLG': 1, 'LG': 2, 'LGIS': 3}

class StairNetDataset(Dataset):
    """Dataset of numpy sample files whose class is encoded in the filename.

    Every entry of ``folder`` is loaded with ``np.load``; its label is the
    text between the last ``_`` and the first following ``.`` of the filename
    and must be a key of ``CLASS_NAMES``.

    Args:
        folder: directory containing the sample files.
        scale_img: if True, divide loaded samples by 255.
    """

    def __init__(self, folder, scale_img=True):
        self.folder_name = folder
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible across runs and processes.
        self.folder = sorted(os.listdir(folder))
        self.scale = scale_img

    def process_label(self, label):
        ''' read label from filename'''
        res = label.split('_')[-1].split('.')[0]
        return res

    def _create_one_hot(self, label):
        ''' creating one-hot label from ordinary label '''
        one_hot = np.zeros(shape=(len(CLASS_NAMES)))
        one_hot[CLASS_NAMES[label]] = 1
        return one_hot

    def __getitem__(self, idx):
        """Return ``(sample, one_hot_label)`` for the file at position ``idx``.

        ``sample`` is a float32 tensor (scaled by 1/255 when ``self.scale``);
        ``one_hot_label`` is a numpy array of length ``len(CLASS_NAMES)``.
        """
        filename = self.folder[idx]
        sample = torch.from_numpy(
            np.load(os.path.join(self.folder_name, filename)).astype('float32'))
        label = self.process_label(filename)
        if self.scale:
            sample = sample.div(255.)
        return sample, self._create_one_hot(label)

    def __len__(self):
        return len(self.folder)

In [None]:
# Build one dataset per split and store them in FLAGS so they are passed to
# the spawned training processes together with the rest of the configuration.
FLAGS['train_dataset'] = StairNetDataset(FLAGS['train_dataset_path'])
FLAGS['val_dataset'] = StairNetDataset(FLAGS['val_dataset_path'])
FLAGS['test_dataset'] = StairNetDataset(FLAGS['test_dataset_path'])

# Visualization of a sample

In [None]:
# Show every frame of one validation sequence side by side.
img, label = FLAGS['val_dataset'][25]

# Derive the number of frames from the sample instead of hard-coding 5, so
# the cell keeps working for other sequence lengths.
n_frames = img.shape[0]

fig = plt.figure(figsize=(16, 8))
for frame_idx in range(n_frames):
    fig.add_subplot(1, n_frames, frame_idx + 1)
    plt.imshow(img[frame_idx])
plt.show()

In [None]:
def reduce(values):
    '''
    Compute the arithmetic mean of per-core values.
    Args:
        values : list of values, one calculated on each core
    Returns:
        sum of the values divided by how many there are
    '''
    count = len(values)
    total = sum(values)
    return total / count

# Train and Evaluation loops

In [None]:
def train_epoch(train_dataloader, model, criterion, optimizer, device, scheduler=None, MAX_GRAD_NORM=10):
    """Run one training epoch on this process's device.

    Args:
        train_dataloader: iterable of ``(sequence_batch, one_hot_labels)`` with a ``len``.
        model: network returning logits of shape (batch, num_classes).
        criterion: loss called as ``criterion(logits, class_indices)``
            (e.g. ``nn.CrossEntropyLoss``).
        optimizer: optimizer stepped through ``xm.optimizer_step``.
        device: device the batches are moved to.
        scheduler: optional LR scheduler, stepped once per batch.
        MAX_GRAD_NORM: max norm used for gradient clipping.

    Returns:
        ``(epoch_loss, tr_accuracy)``: the mean over steps of the loss averaged
        across processes, and the mean per-step accuracy of this process only.
    """
    model.train()

    tr_loss, tr_accuracy = 0, 0
    nb_tr_steps = 0

    for idx, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc='Train step'):
        seq_batch = batch[0].to(device)
        labels = batch[1].to(device)

        output = model(seq_batch)

        # Labels arrive one-hot encoded; convert them to class indices.
        targets = torch.argmax(labels, dim=1) # shape (batch_size)
        predictions = torch.argmax(output, dim=1) # shape (batch_size)

        # The loss takes (input, target): logits first, then the targets.
        loss = criterion(output, targets)

        nb_tr_steps += 1

        # backward pass
        optimizer.zero_grad()
        loss.backward()

        # Clip after backward() so the freshly computed gradients are the ones
        # being clipped, and before the optimizer consumes them.
        torch.nn.utils.clip_grad_norm_(
            parameters=model.parameters(), max_norm=MAX_GRAD_NORM
        )

        xm.optimizer_step(optimizer)
        if scheduler is not None:
            scheduler.step()

        # per-step accuracy on this process's shard
        tmp_tr_accuracy = accuracy_score(targets.cpu().numpy(), predictions.cpu().numpy())
        tr_accuracy += tmp_tr_accuracy

        # average the step loss across all processes
        loss_reduced = xm.mesh_reduce('train_loss_reduce',
                                      loss,
                                      lambda vals: sum(vals) / len(vals))
        xm.master_print(f'Train loss: {tr_loss} Added loss: {loss_reduced.item()}')
        tr_loss += loss_reduced.item()

    xm.master_print(f'Loss: {tr_loss} Nb steps: {nb_tr_steps}')
    epoch_loss = tr_loss / nb_tr_steps
    xm.master_print(f'Epoch loss: {epoch_loss}')
    tr_accuracy = tr_accuracy / nb_tr_steps
    xm.master_print(f"Training loss epoch: {epoch_loss}")
    xm.master_print(f"Training accuracy epoch: {tr_accuracy}")
    return epoch_loss, tr_accuracy

def evaluate_epoch(test_dataloader, model, criterion, device, metric=False):
    """Evaluate the model for one pass over ``test_dataloader`` without gradients.

    Args:
        test_dataloader: iterable of ``(sequence_batch, one_hot_labels)`` with a ``len``.
        model: network returning logits of shape (batch, num_classes).
        criterion: loss called as ``criterion(logits, class_indices)``
            (e.g. ``nn.CrossEntropyLoss``).
        device: device the batches are moved to.
        metric: currently unused; kept for interface compatibility.

    Returns:
        ``(eval_loss, eval_accuracy)``: the mean over steps of the loss averaged
        across processes, and the mean per-step accuracy of this process only.
    """
    # put model in evaluation mode
    model.eval()

    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps = 0

    with torch.no_grad():
        for idx, batch in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Evaluation step'):
            seq_batch = batch[0].to(device)
            labels = batch[1].to(device)

            outputs = model(seq_batch)

            # Labels arrive one-hot encoded; convert them to class indices.
            targets = torch.argmax(labels, dim=1) # shape (batch_size)
            predictions = torch.argmax(outputs, dim=1) # shape (batch_size)

            # The loss takes (input, target): logits first, then the targets.
            loss = criterion(outputs, targets)

            # average the step loss across all processes
            loss_reduced = xm.mesh_reduce('eval_loss_reduce',
                                          loss,
                                          lambda vals: sum(vals) / len(vals))

            eval_loss += loss_reduced.item()
            nb_eval_steps += 1

            # per-step accuracy on this process's shard
            tmp_eval_accuracy = accuracy_score(targets.cpu().numpy(), predictions.cpu().numpy())
            eval_accuracy += tmp_eval_accuracy

    eval_loss = eval_loss / nb_eval_steps
    eval_accuracy = eval_accuracy / nb_eval_steps
    xm.master_print(f"Validation Loss: {eval_loss}")
    xm.master_print(f"Validation Accuracy: {eval_accuracy}")
    return eval_loss, eval_accuracy

In [None]:
# Start training processes
def _mp_fn(rank, flags):
    """Per-process entry point for ``xmp.spawn``: train and validate the model.

    Args:
        rank: process index supplied by ``xmp.spawn`` (unused; the ordinal is
            read from ``xm.get_ordinal()`` instead).
        flags: configuration dict holding the datasets, the wrapped model and
            the hyper-parameters; it is also published as the global ``FLAGS``.

    After every epoch the model weights are saved to ``model_weights.pth``.
    """
    global FLAGS
    FLAGS = flags
    device = xm.xla_device()

    # getting distributed train sampler
    train_sampler = DistributedSampler(
        dataset=flags['train_dataset'],
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True
    )
    # distributed train dataloader
    train_dataloader = DataLoader(
        dataset=flags['train_dataset'],
        batch_size=flags['batch_size'],
        sampler=train_sampler,
        num_workers=flags['num_workers'],
        drop_last=True
    )

    # getting distributed val sampler
    val_sampler = DistributedSampler(
        dataset=flags['val_dataset'],
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False
    )
    # distributed val dataloader; it must use the validation sampler, whose
    # indices are derived from the validation dataset
    val_dataloader = DataLoader(
        dataset=flags['val_dataset'],
        batch_size=flags['batch_size'],
        sampler=val_sampler,
        num_workers=flags['num_workers'],
        drop_last=True
    )

    model = FLAGS['model'].to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam( #torch.optim.SGD(
        model.parameters(),
        lr=FLAGS['learning_rate'],
        weight_decay=0.1
    )
    xm.master_print('Train start...')

    for epoch in range(flags['num_epochs']):
        # Re-seed the shuffle; without this every epoch sees the same order.
        train_sampler.set_epoch(epoch)

        # train step
        train_parallel_loader = pl.ParallelLoader(
            train_dataloader, [device]).per_device_loader(device)
        train_loss, train_acc = train_epoch(
            train_dataloader=train_parallel_loader,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            device=device
        )
        xm.master_print(f'Train loss: {train_loss} \t Train acc: {train_acc}')

        del train_parallel_loader
        gc.collect()

        xm.master_print('Eval step...')
        # evaluation step
        val_parallel_loader = pl.ParallelLoader(
            val_dataloader, [device]).per_device_loader(device)
        val_loss, val_acc = evaluate_epoch(
            test_dataloader=val_parallel_loader,
            model=model,
            criterion=criterion,
            device=device,
        )

        del val_parallel_loader
        gc.collect()

        xm.save(model.state_dict(), f'model_weights.pth')


In [None]:
# Launch FLAGS['num_cores'] processes running _mp_fn, each receiving FLAGS;
# the 'fork' start method is used.
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'],
          start_method='fork')