# Tutorial: Model Training and Inference on Local and Mixed Datasets

# 0 Configure the Environment

In [None]:
# Clone timm and install it in editable mode. The rest of the notebook runs from inside
# the repo: train.py / validate.py are invoked with %run and ./data is relative to it.
!git clone https://github.com/rwightman/pytorch-image-models.git

%cd pytorch-image-models/
%pip --no-cache-dir install -r requirements.txt
# `setup.py develop` is deprecated by setuptools; an editable pip install does the same
# job and also works when the project is built from pyproject.toml.
%pip install -e .

# 1 MixedDataset: Querying, Training and Testing

## 1.1 Training Preparation with Local Datasets

### 1.1.1 Download ImageNette

In [None]:
# Download the imagenette2-320 archive into ./data/, unpack it there, and delete the archive.
!wget -P ./data/ https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz
!tar -xf ./data/imagenette2-320.tgz -C ./data/ && rm ./data/imagenette2-320.tgz

### 1.1.2 Download CIFAR-100

In [None]:
# Download the CIFAR-100 python archive into ./data/, unpack it there, and delete the archive.
# The next cell reads the extracted cifar-100-python directory and exports it as image folders.
!wget -P ./data/ https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
!tar -xf ./data/cifar-100-python.tar.gz -C ./data/ && rm ./data/cifar-100-python.tar.gz

In [None]:
import os
from PIL import Image
import numpy as np
from tqdm import trange


# Both paths are relative to the repo root that `%cd pytorch-image-models/` switched into,
# which is where the wget/tar cells above extracted ./data/ (works outside Colab as well).
cifar100_python = './data/cifar-100-python'
cifar100_images = './data/cifar100_images'

def unpickle(file):
    """Load and return the object stored in a CIFAR python-pickle file.

    Args:
        file: Path to the pickle file (e.g. the ``meta``, ``train`` or ``test`` batch).

    Returns:
        Whatever object the file contains (used below as a dict).
    """
    # Imported locally so the helper is self-contained.
    # NOTE: pickle can execute arbitrary code — only load files from a trusted source.
    import pickle

    # encoding='latin1' tells pickle how to decode 8-bit strings written by Python 2 pickles.
    with open(file, mode='rb') as handle:
        return pickle.load(handle, encoding='latin1')

def save_dirs(path):
    """Create directory ``path`` (and any missing parents) unless it already exists.

    ``exist_ok=True`` replaces a separate existence check, so there is no window
    between checking and creating in which another process could create the
    directory and make ``makedirs`` fail.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

meta_dict = unpickle(os.path.join(cifar100_python, 'meta'))
save_dirs(cifar100_images)
fine_label_names = meta_dict['fine_label_names']

# Export CIFAR-100 as <cifar100_images>/<split>/<fine label>/<file name>.
# The pickled 'test' batch is written out as the 'val' split.
for data_set in ['train', 'val']:
    print('Unpickling {} dataset......'.format(data_set))
    for fine_label_name in fine_label_names:
        save_dirs(os.path.join(cifar100_images, data_set, fine_label_name))

    source_batch = 'test' if data_set == 'val' else data_set
    data_dict = unpickle(os.path.join(cifar100_python, source_batch))

    # Each row is reshaped to (3, 32, 32), i.e. channels first; PIL needs channels last.
    data = np.array(data_dict['data']).reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    label = data_dict['fine_labels']
    filenames = data_dict['filenames']

    for i in trange(data.shape[0]):
        img = Image.fromarray(data[i])
        img.save(os.path.join(cifar100_images, data_set, fine_label_names[label[i]], filenames[i]))

In [None]:
import os
import matplotlib.pyplot as plt
import copy
import numpy as np

def plot_classes_number(path_to_images):
    """Show a bar chart of how many files each class folder under ``path_to_images`` holds.

    Args:
        path_to_images: Directory whose immediate sub-directories are class
            folders (for example ``./data/cifar100_images/train``). Every entry
            inside a class folder is counted as one image.

    Returns:
        None; the figure is displayed with ``plt.show()``.
    """
    # os.listdir order is filesystem-dependent, so sort for a stable bar order,
    # and skip plain files (e.g. index/text files) that are not class folders.
    class_names = sorted(
        name for name in os.listdir(path_to_images)
        if os.path.isdir(os.path.join(path_to_images, name))
    )
    class_counts = np.array(
        [len(os.listdir(os.path.join(path_to_images, name))) for name in class_names]
    )

    positions = np.arange(len(class_names))
    fig, ax = plt.subplots(figsize=(22, 8), dpi=300)
    # Bars sit 1 unit apart, so a width below 1 keeps neighbouring bars from overlapping.
    ax.bar(positions, class_counts, width=0.8, align='center')
    ax.set_ylabel("Number per Class")
    ax.set_xlabel("Class name")
    ax.set_title("Number of each Class")
    ax.set_xticks(positions)
    ax.set_xticklabels(class_names, size='small', rotation=90)
    # Rotated class names are long; make room for them below the axis.
    fig.tight_layout()
    plt.show()

## 1.2 Training Preparation with a Mixed Dataset Queried from VisionKG

In [None]:
# install our vision utils
# Install the VisionKG helper package; the next cell imports vision_utils (semkg_api) from it.
# NOTE(review): `--force` is presumably resolved by pip as `--force-reinstall`, which also
# reinstalls every dependency — consider adding `--no-deps` if that is not intended.
!python -m pip install git+https://github.com/cqels/vision.git --force

In [None]:
import json
import os
from os.path import join as opj
from torch_model_zoo.utils import dataset_split, check_instances_categories, check_download_images
import json
from vision_utils import semkg_api, data

# SPARQL query for VisionKG: up to 50 classification-annotated images labelled "bus",
# UNION up to 50 labelled "pickup_truck" (each sub-select carries its own LIMIT).
# NOTE(review): ?label is projected but never bound in either WHERE clause, and
# ?datasetName is bound but not projected. Downstream code only reads ?imageName and
# ?labelResource from each result row.
query_string='''#Give me the images containing bus and pickup-truck
PREFIX cv: <http://vision.semkg.org/onto/v0.1/>
PREFIX schema: <http://schema.org/>

SELECT DISTINCT ?image ?label ?imageName ?labelResource
WHERE {
    {
        SELECT DISTINCT ?image ?label ?imageName ?labelResource
        WHERE {
            ?image schema:isPartOf / schema:name ?datasetName .
            ?image cv:hasAnnotation ?annotation .
            ?image schema:name ?imageName.
            ?annotation a cv:ClassificationAnnotation.
            ?annotation cv:hasLabel ?labelResource .
            ?labelResource cv:label "bus" .
        }
        LIMIT 50
    }
    UNION
    {
        SELECT DISTINCT ?image ?label ?imageName ?labelResource
        WHERE {
            ?image schema:isPartOf / schema:name ?datasetName .
            ?image cv:hasAnnotation ?annotation .
            ?image schema:name ?imageName.
            ?annotation a cv:ClassificationAnnotation.
            ?annotation cv:hasLabel ?labelResource .
            ?labelResource cv:label "pickup_truck" .
        }
        LIMIT 50
    }
}
'''

result = semkg_api.query(query_string)
ROOT_PATH = os.path.abspath('./')
json_f_name = 'test_query_api_image.json'
path_to_anno_mixedDatasets = opj(ROOT_PATH, 'testData/mixedDatasets/vkg/meta/')
path_to_images_mixedDatasets = opj(ROOT_PATH, 'testData/mixedDatasets/vkg/')
os.makedirs(path_to_anno_mixedDatasets, exist_ok=True)
path_to_anno = opj(path_to_anno_mixedDatasets, json_f_name)

# Keep the raw query response on disk for provenance / offline inspection.
with open(path_to_anno, "w") as f:
    json.dump(result, f)

# labelResource is a URI, so split it on '/', not os.sep ('\\' on Windows), and build the
# HTTP URL with '/'. The second-to-last URI segment becomes the image path prefix on the
# view API; the last segment is used as the category name.
label_segments = [row['labelResource'].split('/') for row in result]
query_results = {
    'images': ['https://vision-api.semkg.org/api/view?image=/' + segments[-2] + '/' + row['imageName']
               for segments, row in zip(label_segments, result)],
    'categories': [segments[-1] for segments in label_segments],
}

In [None]:
from sklearn.model_selection import train_test_split
import os
import urllib.request
import time

def create_classmap(categories):
    """Assign each distinct category name a contiguous integer class index.

    Indices follow the sorted order of the names (0, 1, 2, ...), so the same
    set of categories always yields the same mapping.

    Args:
        categories: Iterable of category names; duplicates are allowed.

    Returns:
        Dict mapping category name -> class index.
    """
    classmap = {}
    for index, category in enumerate(sorted(set(categories))):
        classmap[category] = index
    return classmap

def pull_vkg_images(images, categories, base_folder, classmap, delay=1.0):
    """Download images into ``<base_folder>/<category>/`` and write a split index file.

    The split name is the last path component of ``base_folder`` (``'train'`` or
    ``'val'`` in this notebook). A ``<split>.txt`` file is written next to
    ``base_folder`` with one line per image: ``<split>/<category>/<file name> <class index>``.

    Args:
        images: Image URLs; the file name is the text after the last '/'.
        categories: Category name for each URL (same order as ``images``).
        base_folder: Destination directory for this split.
        classmap: Mapping from category name to integer class index.
        delay: Seconds to wait before each download (throttles requests to the server).
    """
    split_dir = os.path.normpath(base_folder)
    # Take the split from the folder's own name. A substring test on the whole path
    # ("'train' in base_folder") misfires whenever any parent directory contains "train".
    split = os.path.basename(split_dir)
    save_dir = os.path.dirname(split_dir)
    # Also guarantees save_dir exists before the index file is opened.
    os.makedirs(split_dir, exist_ok=True)
    txt_filename = os.path.join(save_dir, split + '.txt')

    with open(txt_filename, 'w') as txt_file:
        for url, category in zip(images, categories):
            folder = os.path.join(split_dir, category)
            os.makedirs(folder, exist_ok=True)
            image_name = url.split('/')[-1]
            time.sleep(delay)
            urllib.request.urlretrieve(url, os.path.join(folder, image_name))
            txt_file.write(f"{os.path.join(split, category, image_name)} {classmap[category]}\n")

# Fixed seed so the train/val split, and therefore every downstream result, is reproducible.
SPLIT_SEED = 42
train_urls, val_urls, train_categories, val_categories = train_test_split(
    query_results['images'], query_results['categories'],
    test_size=0.2, stratify=query_results['categories'], random_state=SPLIT_SEED)

# One class index per category, shared by both splits and written to classmap.txt.
classmap = create_classmap(query_results['categories'])
pull_vkg_images(train_urls, train_categories, opj(path_to_images_mixedDatasets, 'train'), classmap)
pull_vkg_images(val_urls, val_categories, opj(path_to_images_mixedDatasets, 'val'), classmap)
with open(os.path.join(path_to_images_mixedDatasets, 'classmap.txt'), 'w') as map_file:
    for category, num in classmap.items():
        map_file.write(f"{category} {num}\n")

## 1.3 Set Directories and Other Parameters

In [None]:
import os

# Dataset root handed to train.py; it contains the train/ and val/ split folders.
data_dir = path_to_images_mixedDatasets
test_images = os.path.join(path_to_images_mixedDatasets, 'val')
output_training_dir = './output/train/'
output_test_dir = './output/test/'
# Size the classifier head from the class folders of the training split. Only
# directories are counted, so stray files in the split folder are ignored.
train_split_dir = os.path.join(path_to_images_mixedDatasets, 'train')
num_classes = sum(
    os.path.isdir(os.path.join(train_split_dir, entry)) for entry in os.listdir(train_split_dir)
)
used_model = 'resnet18'

## 1.4 Training on the queried mixed dataset or local data

In [None]:
# Fine-tune `used_model` from pretrained weights on the train/ and val/ folders under
# data_dir (created in 1.2). `{var}` expands notebook variables into the %run arguments.
# NOTE(review): data_dir is passed both positionally and via --data-dir; one of them is
# presumably redundant depending on the train.py version — confirm against the clone.
# Outputs go under output_training_dir; section 1.5 looks for model_best.pth.tar there.
%run train.py {data_dir} \
--data-dir {data_dir} \
--output {output_training_dir} \
--model {used_model} \
--sched 'cosine' \
--epochs=50 \
--color-jitter=0 \
--num-classes {num_classes} \
--amp \
--lr=1e-4 \
--min-lr=1e-8 --warmup-epochs=3 --train-interpolation=bilinear --aa=v0 \
--checkpoint-hist 1 \
--pretrained \
--opt=adamw \
--weight-decay=1e-4 \
--batch-size 16

## 1.5 Verify the Checkpoint File

In [None]:
# Presumably one sub-directory per training run; take the most recently modified one.
all_existed_dirs = [
    os.path.join(output_training_dir, sub_dir)
    for sub_dir in os.listdir(output_training_dir)
    if os.path.isdir(os.path.join(output_training_dir, sub_dir))
]
# Fail with a clear message instead of max()'s "empty sequence" error.
assert all_existed_dirs, 'no training run found under {}'.format(output_training_dir)
latest_output_dir = max(all_existed_dirs, key=os.path.getmtime)
checkpoint_file = os.path.join(latest_output_dir, "model_best.pth.tar")
assert os.path.isfile(checkpoint_file), '{} not exist'.format(checkpoint_file)
checkpoint_file = os.path.abspath(checkpoint_file)

## 1.6 Testing on the mixed dataset or local data

In [None]:
# Evaluate the checkpoint found in 1.5 on the held-out val folder (test_images),
# using the same model architecture and class count as training.
%run validate.py {test_images} \
--model {used_model} \
--num-classes {num_classes} \
--checkpoint {checkpoint_file}