<a href="https://colab.research.google.com/github/jmarrietar/ocular/blob/master/notebooks/PAWS_DR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Mount Drive
from google.colab import drive
from google.colab import auth
from googleapiclient.http import MediaFileUpload
from googleapiclient.discovery import build

In [3]:
# Mount Google Drive (the config's logging folder lives under
# /content/drive/MyDrive/...) and authenticate the Colab user with Google.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)
auth.authenticate_user()

Mounted at /content/drive


In [4]:
!pip install colab_ssh --upgrade
from colab_ssh import launch_ssh

Collecting colab_ssh
  Downloading https://files.pythonhosted.org/packages/74/c0/a2d6cc985d9496968b80203105f20c8d3845effa45dd4cb46c22879f1e44/colab_ssh-0.3.15-py3-none-any.whl
Installing collected packages: colab-ssh
Successfully installed colab-ssh-0.3.15


In [5]:
# Install colab_ssh on google colab
!pip install colab_ssh --upgrade

from colab_ssh import launch_ssh_cloudflared, init_git_cloudflared
launch_ssh_cloudflared(password="123456")

Requirement already up-to-date: colab_ssh in /usr/local/lib/python3.7/dist-packages (0.3.15)


In [6]:
# Configure git for the suncet repo (branch feature/DR-images) through colab_ssh.
# The GitHub personal access token is requested interactively so it is never
# saved in the notebook; submitting an empty answer passes an empty token,
# exactly as the previous hard-coded placeholder did.
import getpass

init_git_cloudflared("https://github.com/jmarrietar/suncet.git",
                     personal_token=getpass.getpass("GitHub personal access token (blank for none): "),
                     branch="feature/DR-images",
                     email="jmarrietar@gmail.com",
                     username="jmarrietar")

In [None]:
#########################################

In [9]:
# Show the current working directory (IPython automagic for %pwd); the recorded
# output is '/content/suncet', where the relative config paths below resolve.
pwd

'/content/suncet'

In [12]:
import argparse

import torch.multiprocessing as mp

import pprint
import yaml

In [8]:
cd suncet/

/content/suncet


In [14]:
# Command-line style options: which config file to load, which devices to use,
# and which training script to select. In this notebook the parser is only fed
# a hand-written argument list (see the parse_args cell).
SCRIPT_CHOICES = ['paws_train', 'suncet_train', 'fine_tune', 'snn_fine_tune']

parser = argparse.ArgumentParser()
parser.add_argument('--fname', type=str, default='configs.yaml',
                    help='name of config file to load')
parser.add_argument('--devices', type=str, nargs='+', default=['cuda:0'],
                    help='which devices to use on local machine')
parser.add_argument('--sel', type=str, choices=SCRIPT_CHOICES,
                    help='which script to run');

_StoreAction(option_strings=['--sel'], dest='sel', nargs=None, const=None, default=None, type=<class 'str'>, choices=['paws_train', 'suncet_train', 'fine_tune', 'snn_fine_tune'], help='which script to run', metavar=None)

In [15]:
# A notebook has no real argv, so spell out the command line explicitly.
cli_argv = ['--sel', 'paws_train', '--fname', 'configs/paws/dr_train.yaml']
args = parser.parse_args(cli_argv)

In [37]:
import logging
import sys
from collections import OrderedDict

import numpy as np

import torch

import src.resnet as resnet
import src.wide_resnet as wide_resnet
from src.utils import (
    gpu_timer,
    init_distributed,
    WarmupCosineSchedule,
    CSVLogger,
    AverageMeter
)
from src.losses import (
    init_paws_loss,
    make_labels_matrix
)
from src.data_manager import (
    init_data,
    make_transforms,
    make_multicrop_transform
)
from src.sgd import SGD
from src.lars import LARS

import apex
from torch.nn.parallel import DistributedDataParallel

In [38]:

import logging

# Root logger shared by the cells below. basicConfig() attaches a stderr
# handler, but the root logger's default level is WARNING, so without an
# explicit level every `logger.info(...)` call (e.g. 'loaded params...',
# the rank/world-size message) would be silently dropped.
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [39]:
# PAWS config for the DR dataset, relative to the working directory (/content/suncet).
# NOTE(review): the same path is passed as --fname to parse_args earlier; keep them in sync.
CONFIG_FNAME = 'configs/paws/dr_train.yaml'
fname = CONFIG_FNAME

In [40]:
# -- load script params
# Parse the YAML config at `fname` into a nested dict, then echo it so the
# run settings are recorded in the notebook output.
params = None
with open(fname, 'r') as config_file:
    params = yaml.load(config_file, Loader=yaml.FullLoader)
logger.info('loaded params...')

pp = pprint.PrettyPrinter(indent=4)
pp.pprint(params)

{   'criterion': {   'classes_per_batch': 15,
                     'me_max': True,
                     'sharpen': 0.25,
                     'supervised_imgs_per_class': 7,
                     'supervised_views': 1,
                     'temperature': 0.1,
                     'unsupervised_batch_size': 64},
    'data': {   'color_jitter_strength': 1.0,
                'data_seed': None,
                'dataset': 'dr',
                'image_folder': 'dr/sample@1000/',
                'label_smoothing': 0.1,
                'multicrop': 6,
                'normalize': True,
                'root_path': 'datasets/',
                'subset_path': 'dr_subsets',
                'unique_classes_per_rank': True,
                'unlabeled_frac': 0.9},
    'logging': {   'folder': '/content/drive/MyDrive/Colab '
                             'Notebooks/PAWS/logs/',
                   'write_tag': 'paws'},
    'meta': {   'copy_data': True,
                'device': 'cuda:0',
              

In [41]:
# Re-bind `args` from the argparse Namespace to the loaded YAML dict so the next
# cell can index it as args['meta'][...], args['data'][...], etc.
args = params

In [42]:
    # Unpack the YAML config held in `args` (a nested dict) into the flat
    # variables used by later cells, initialise the distributed backend, set up
    # log/checkpoint paths and create the CSV logger. The cell keeps the
    # indentation of the function body it was copied from; IPython strips the
    # common leading indent before executing it.
    import os  # needed for the os.path.join calls below

    # ----------------------------------------------------------------------- #
    #  PASSED IN PARAMS FROM CONFIG FILE
    # ----------------------------------------------------------------------- #
    # -- META
    model_name = args['meta']['model_name']
    output_dim = args['meta']['output_dim']
    load_model = args['meta']['load_checkpoint']
    r_file = args['meta']['read_checkpoint']
    copy_data = args['meta']['copy_data']
    use_fp16 = args['meta']['use_fp16']
    use_pred_head = args['meta']['use_pred_head']
    device = torch.device(args['meta']['device'])
    torch.cuda.set_device(device)

    # -- CRITERION
    reg = args['criterion']['me_max']
    supervised_views = args['criterion']['supervised_views']
    classes_per_batch = args['criterion']['classes_per_batch']
    s_batch_size = args['criterion']['supervised_imgs_per_class']
    u_batch_size = args['criterion']['unsupervised_batch_size']
    temperature = args['criterion']['temperature']
    sharpen = args['criterion']['sharpen']

    # -- DATA
    unlabeled_frac = args['data']['unlabeled_frac']
    color_jitter = args['data']['color_jitter_strength']
    normalize = args['data']['normalize']
    root_path = args['data']['root_path']
    image_folder = args['data']['image_folder']
    dataset_name = args['data']['dataset']
    subset_path = args['data']['subset_path']
    unique_classes = args['data']['unique_classes_per_rank']
    multicrop = args['data']['multicrop']
    label_smoothing = args['data']['label_smoothing']
    # Only the CIFAR-10 branch reads data_seed from the config; every other
    # dataset (including 'dr') keeps data_seed = None.
    data_seed = None
    if 'cifar10' in dataset_name:
        data_seed = args['data']['data_seed']
        crop_scale = (0.75, 1.0) if multicrop > 0 else (0.5, 1.0)
        mc_scale = (0.3, 0.75)
        mc_size = 18
    else:
        # The global-crop scale range is narrower when multicrop views are also used.
        crop_scale = (0.14, 1.0) if multicrop > 0 else (0.08, 1.0)
        mc_scale = (0.05, 0.14)
        mc_size = 96

    # -- OPTIMIZATION
    wd = float(args['optimization']['weight_decay'])
    num_epochs = args['optimization']['epochs']
    warmup = args['optimization']['warmup']
    start_lr = args['optimization']['start_lr']
    lr = args['optimization']['lr']
    final_lr = args['optimization']['final_lr']
    mom = args['optimization']['momentum']
    nesterov = args['optimization']['nesterov']

    # -- LOGGING
    folder = args['logging']['folder']
    tag = args['logging']['write_tag']
    # ----------------------------------------------------------------------- #

    # -- init torch distributed backend
    world_size, rank = init_distributed()
    logger.info(f'Initialized (rank/world-size) {rank}/{world_size}')

    # -- log/checkpointing paths
    log_file = os.path.join(folder, f'{tag}_r{rank}.csv')
    # '{epoch}' is plain string concatenation (not part of the f-string), so it
    # stays a literal placeholder here — presumably filled in with .format() later.
    save_path = os.path.join(folder, f'{tag}' + '-ep{epoch}.pth.tar')
    latest_path = os.path.join(folder, f'{tag}-latest.pth.tar')
    best_path = os.path.join(folder, f'{tag}' + '-best.pth.tar')
    load_path = None
    if load_model:
        # Fall back to the latest checkpoint when no explicit file is configured.
        load_path = os.path.join(folder, r_file) if r_file is not None else latest_path

    # -- make csv_logger
    csv_logger = CSVLogger(log_file,
                           ('%d', 'epoch'),
                           ('%d', 'itr'),
                           ('%.5f', 'paws-xent-loss'),
                           ('%.5f', 'paws-me_max-reg'),
                           ('%d', 'time (ms)'))

    logger.info(f'normalize {normalize}')

normalize True


In [43]:
    # -- make data transforms
    # Global-view augmentation pipeline, plus the init transform that
    # make_transforms also returns.
    transform, init_transform = make_transforms(
        dataset_name=dataset_name,
        subset_path=subset_path,
        unlabeled_frac=unlabeled_frac,
        split_seed=data_seed,
        training=True,
        basic_augmentations=False,
        crop_scale=crop_scale,
        color_jitter=color_jitter,
        normalize=normalize)

    # Small multicrop views are only built when the config asks for them;
    # otherwise keep the (multicrop, None) placeholder pair.
    if multicrop > 0:
        multicrop_transform = make_multicrop_transform(
            dataset_name=dataset_name,
            num_crops=multicrop,
            size=mc_size,
            crop_scale=mc_scale,
            normalize=normalize,
            color_distortion=color_jitter)
    else:
        multicrop_transform = (multicrop, None)

In [44]:
# Build the unsupervised and supervised data loaders (and their samplers) from
# the transforms, batch sizes and paths prepared in the previous cells.
loader_kwargs = dict(
    dataset_name=dataset_name,
    transform=transform,
    init_transform=init_transform,
    supervised_views=supervised_views,
    u_batch_size=u_batch_size,
    s_batch_size=s_batch_size,
    unique_classes=unique_classes,
    classes_per_batch=classes_per_batch,
    multicrop_transform=multicrop_transform,
    world_size=world_size,
    rank=rank,
    root_path=root_path,
    image_folder=image_folder,
    training=True,
    copy_data=copy_data,
)
(unsupervised_loader, unsupervised_sampler,
 supervised_loader, supervised_sampler) = init_data(**loader_kwargs)

  cpuset_checked))


In [45]:
# Sanity check: display the DataLoader object returned by init_data.
unsupervised_loader

<torch.utils.data.dataloader.DataLoader at 0x7f7a0329ab50>

In [46]:
# Smoke test: pull a few batches through the unsupervised loader to check that
# loading and augmentation run before starting a full training run.
# NOTE(review): in the recorded run this cell raised `TypeError`; the loop body
# does nothing, so the error comes from iterating the loader itself
# (dataset/transform/collate path) — investigate there.
import itertools

NUM_SMOKE_BATCHES = 5
# islice stops after exactly NUM_SMOKE_BATCHES batches without fetching an
# extra one, and `itr` already counts iterations (no separate counter needed).
for itr, udata in itertools.islice(enumerate(unsupervised_loader), NUM_SMOKE_BATCHES):
    pass

  cpuset_checked))


TypeError: ignored