In [5]:
from PIL import Image
import cv2
import matplotlib.pyplot as plt
%matplotlib inline
import torch
import numpy as np
import pandas as pd
import pickle as pkl
from sklearn.metrics import jaccard_score as IOU
from torchvision import models, transforms, io
from torch.utils.data import Dataset, DataLoader
import torch.nn.utils.prune as prune
import utils
import os
import time
import copy

# Data Loading

In [6]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DATASET_PATH = 'ADE20K_2021_17_01/'
index_file = 'index_ade20k.pkl'
with open('{}/{}'.format(DATASET_PATH, index_file), 'rb') as f:
    index_ade20k = pkl.load(f)

objects_mat = index_ade20k['objectPresence']

# Find 150 most common object IDs and non-common object IDs
total_object_counts = np.sum(objects_mat, axis=1)
object_count_ids = np.argsort(total_object_counts)[::-1]
most_common_obj_ids = object_count_ids[:150]
irrelevant_obj_ids = object_count_ids[150:]
# Find image IDs where no irrelevant objects appear
irrelevant_obj_counts = np.sum(objects_mat[irrelevant_obj_ids], axis=0)
good_image_ids = np.argwhere(irrelevant_obj_counts == 0).flatten()
# Only common objects included
common_objects_mat = objects_mat[np.ix_(most_common_obj_ids, good_image_ids)]

# Maps {obj_ids: 0-149}
obj_id_map = {sorted(most_common_obj_ids)[idx]: idx + 1 for idx in range(150)}
obj_id_map[-1] = 0

# Pick out images to train/evaluate on
train_image_ids = []
test_image_ids = []
for i in good_image_ids:
    if 'training' in index_ade20k['folder'][i]:
        train_image_ids.append(i)
    elif 'validation' in index_ade20k['folder'][i]:
        test_image_ids.append(i)
    else:
        raise Exception('Invalid folder name.')

In [7]:
class SegmentationDataset(Dataset):
    def __init__(self, image_ids, root_dir, index_mat, transform=None, target_transform=None):
        """
        Args:
            image_ids (list): list of image IDs from ADE20K
            root_dir (string): Directory with all the images.
            index_mat (array): object array from index_ade20k.pkl
            transform (callable, optional): Optional transform to be applied
                on a sample.
            target_transform (callable, optional): Optional transform to be applied
                on a sample segmentation label.
        """
        self.image_ids = image_ids
        self.root_dir = root_dir
        self.index_ade20k = index_mat
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        img_name = os.path.join(self.root_dir, self.index_ade20k['folder'][image_id], 
                                self.index_ade20k['filename'][image_id])
        img_info = utils.loadAde20K(img_name)
        
        image = io.read_image(img_info['img_name']).float()
        class_mask = Image.fromarray(img_info['class_mask'], mode='I')
        
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(class_mask)
            
        sample = (image, label)

        return sample

In [8]:
input_size = 224
transform = transforms.Compose([
                transforms.Resize(input_size),
                transforms.CenterCrop(input_size),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])

target_transform = transforms.Compose([
                transforms.Resize(input_size, interpolation=0),
                transforms.CenterCrop(input_size),
                transforms.ToTensor()
            ])

  "Argument interpolation should be of type InterpolationMode instead of int. "


In [9]:
num_samples = 4
batch_size = 2
training_data = SegmentationDataset(train_image_ids[:num_samples], './', index_ade20k, transform=transform, target_transform=target_transform)
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=False)
testing_data = SegmentationDataset(test_image_ids[:num_samples], './', index_ade20k, transform=transform, target_transform=target_transform)
test_dataloader = DataLoader(testing_data, batch_size=batch_size, shuffle=False)

# Load Pre-trained Model

In [10]:
def get_parameter_size(model):
    """
    Return model size in terms of parameters
    Each parameter is a float32 - 4 bytes
    """
    num_params = 0
    for p in model.parameters():
        num_params += torch.count_nonzero(p.flatten())
        
    total_bytes = num_params.item() / 4
    kb = total_bytes / 1000
    
    return {"# Params": num_params.item(),
            "Size in KB": kb}

In [14]:
model = models.segmentation.fcn_resnet50(pretrained=False, num_classes=151).to(device=device)
model.load_state_dict(torch.load('../epochs_20_weights.pkl', map_location=torch.device('cpu')))

<All keys matched successfully>

In [15]:
print(get_parameter_size(model))

{'# Params': 33023703, 'Size in KB': 8255.92575}


### Train original model

In [17]:
def encode_label(label_arr, obj_id_map):
    """
    Encode labels for evaluating loss
    label_arr (tensor): B x 1 x H x W
    """
    convert_label_ids = lambda i: obj_id_map[i-1]
    vect_convert_label_ids = np.vectorize(convert_label_ids)
    
    encoded_label = vect_convert_label_ids(label_arr.squeeze().numpy())
    
    return torch.tensor(encoded_label, dtype=torch.long)

In [20]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = torch.nn.CrossEntropyLoss()
epochs = 2
load_data_start = time.time()
for i in range(epochs):
    print('###### Epoch {} ######'.format(i+1))
    epoch_start = time.time()

    # training pass
    running_loss = 0
    batch_num = 0
    for images, labels in train_dataloader:
        batch_start = time.time()
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        output = model(images)['out']
        labels = encode_label(labels, obj_id_map).to(device)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        batch_num += 1
        print('Batch {}/{} finished... {} seconds'.format(batch_num,len(train_dataloader), time.time() - batch_start))
    print('-----> Training loss: {}'.format(running_loss/len(train_dataloader)))
    print("Training time: {} seconds".format(time.time() - epoch_start))
#     torch.save(model.state_dict(), result_path+'/epochs_{}_weights.pkl'.format(i+1))

    # testing pass
    test_start = time.time()
    running_accuracy = 0
    running_iou = 0
    with torch.no_grad():
        for images, labels in test_dataloader:
            images = images.to(device)
            output = model(images)['out']
            labels = encode_label(labels, obj_id_map).to(device)
            probs = torch.nn.functional.softmax(output, dim=1)
            preds = torch.argmax(probs, dim=1, keepdim=True)
            num_correct = torch.sum((preds == labels).to(int)).item()
            acc = num_correct/(input_size*input_size*len(images))
            running_accuracy += acc
            print('Testing pixel accuracy: {}'.format(num_correct/(input_size*input_size*len(images))))
            iou = IOU(labels.detach().numpy().reshape(-1), preds.detach().numpy().reshape(-1))
            running_iou += iou
            print('Testing IOU score: {}'.format(iou))
        print('-----> Overall testing pixel accuracy: {}'.format(running_accuracy / len(test_dataloader)))
        print('-----> Overall testing IOU accuracy: {}'.format(running_iou / len(test_dataloader)))
    print("Testing time: {} seconds".format(time.time() - test_start))

    print("Epoch completed in {} seconds.".format(time.time() - epoch_start))

print('\n' + '#'*100)
print("DONE TRAINING in {} seconds.".format(time.time() - load_data_start))
print('#'*100 + '\n')

###### Epoch 1 ######
Batch 1/2 finished... 3.6026101112365723 seconds
Batch 2/2 finished... 3.639173746109009 seconds
-----> Training loss: 1.014138251543045
Training time: 7.943634986877441 seconds
Testing pixel accuracy: 0.5658581792091837
Testing IOU score: 0.2772640306122449
Testing pixel accuracy: 0.4554866868622449
Testing IOU score: 0.23334861288265307
-----> Overall testing pixel accuracy: 0.5106724330357143
-----> Overall testing IOU accuracy: 0.255306321747449
Testing time: 2.5647149085998535 seconds
Epoch completed in 10.508388996124268 seconds.
###### Epoch 2 ######
Batch 1/2 finished... 3.8629257678985596 seconds
Batch 2/2 finished... 3.6279048919677734 seconds
-----> Training loss: 0.8560846149921417
Training time: 8.167415142059326 seconds
Testing pixel accuracy: 0.5574477838010204
Testing IOU score: 0.27103595344387754
Testing pixel accuracy: 0.4573301977040816
Testing IOU score: 0.23524194834183673
-----> Overall testing pixel accuracy: 0.5073889907525511
-----> Overa

### Prune low weights

In [16]:
class ThresholdPruning(prune.BasePruningMethod):
    PRUNING_TYPE = "unstructured"

    def __init__(self, threshold):
        self.threshold = threshold

    def compute_mask(self, tensor, default_mask):
        return torch.abs(tensor) > self.threshold

In [44]:
model_copy = copy.deepcopy(model)
thresh = 0.0025
params_to_prune = [(module, "weight") for _, module in model_copy.named_modules() if isinstance(module, torch.nn.Conv2d)]

prune.global_unstructured(params_to_prune, pruning_method=ThresholdPruning, threshold=thresh)

for _, module in model_copy.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.remove(module, 'weight')

In [45]:
print(get_parameter_size(model))
print(get_parameter_size(model_copy))

{'# Params': 33023703, 'Size in KB': 8255.92575}
{'# Params': 26041274, 'Size in KB': 6510.3185}


### Compression Sizes
Original: {'# Params': 33023703, 'Size in KB': 8255.92575}

Threshold: 0.1 ---
{'# Params': 76445, 'Size in KB': 19.11125} --- 99.8% compression

Threshold: 0.025 ---
{'# Params': 2211294, 'Size in KB': 552.8235} --- 93.3% compression

Threshold: 0.01 ---
{'# Params': 10571276, 'Size in KB': 2642.819} --- 68.0% compression

Threshold: 0.0025 --- 
{'# Params': 26041274, 'Size in KB': 6510.3185} --- 21.1% compression

Threshold: 0.001 ---
{'# Params': 30165420, 'Size in KB': 7541.355} --- 8.66% compression

In [37]:
matching_flag = True
for p1, p2 in zip(model.parameters(), model_copy.parameters()):
    if p1.data.ne(p2.data).sum() > 0:
        matching_flag = False
print("Copied weights" if matching_flag else "Not matching weights")

Not matching weights
