# Shuffle and learn

Train encoder to predict if 3 frames are in correct temporal order or not.  
Paper: https://arxiv.org/abs/1603.08561

In [1]:
import cv2

import sys

sys.path.append('/scratch/mz2476/DL/project/')

import os
import random

import numpy as np
import pandas as pd

import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['figure.figsize'] = [5, 5]
matplotlib.rcParams['figure.dpi'] = 200
from ssl_project.data_loaders import plot_utils

import imageio

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from ssl_project.data_loaders.data_helper import UnlabeledDataset, LabeledDataset
from ssl_project.data_loaders.helper import collate_fn, draw_box
from ssl_project import constants

from ssl_project.preprocessing import top_down_segmentation


from ssl_project.utils import to_np

In [2]:
torch.cuda.is_available()

True

In [3]:
from ssl_project.constants import *
from ssl_project.paths import *

In [4]:
from logger_hparams import HyperparamsSummaryTensorBoardLogger

from ssl_project.ssl_ideas.preprocessing import TripleDataset
from ssl_project.ssl_ideas.model import SET_SEED, ShuffleAndLearnModel

In [5]:
import pytorch_lightning as pl 
from argparse import Namespace

# Train model

In [9]:
SET_SEED()

In [10]:
LOGS_DIR   = "lightning_logs"
hparams = Namespace(
    fit_all_encoders=False,
    lr=3e-3,
    num_workers=8,
    batch_size=64,
)

MODEL_NAME = f"first_try_many_encoders={hparams.fit_all_encoders}"
version = "04"

In [11]:
logger = HyperparamsSummaryTensorBoardLogger(LOGS_DIR, name=MODEL_NAME, version=version)
model  = ShuffleAndLearnModel(hparams=hparams)

In [12]:
list(name for name, value in model.model.named_children())

['resnet_encoder', 'decoder', 'clf']

In [13]:
model.parameters()

<generator object Module.parameters at 0x2b1bf10e3740>

In [14]:
trainer = pl.Trainer(
    gpus=[0], 
#     auto_lr_find=True,
    show_progress_bar=True,
    val_check_interval=0.25,
#     train_percent_check=0.001,
#     val_percent_check=0.01,
    logger=logger,
    precision=32,
)

INFO:lightning:GPU available: True, used: True
INFO:lightning:CUDA_VISIBLE_DEVICES: [0]


In [None]:
trainer.fit(model, )

INFO:lightning:Set SLURM handle signals.
INFO:lightning:
   | Name                                                 | Type               | Params
----------------------------------------------------------------------------------------
0  | model                                                | ShuffleAndLearnNet | 18 M  
1  | model.resnet_encoder                                 | encoder            | 11 M  
2  | model.resnet_encoder.resnet_encoder                  | Sequential         | 11 M  
3  | model.resnet_encoder.resnet_encoder.0                | Conv2d             | 9 K   
4  | model.resnet_encoder.resnet_encoder.1                | BatchNorm2d        | 128   
5  | model.resnet_encoder.resnet_encoder.2                | ReLU               | 0     
6  | model.resnet_encoder.resnet_encoder.3                | MaxPool2d          | 0     
7  | model.resnet_encoder.resnet_encoder.4                | Sequential         | 147 K 
8  | model.resnet_encoder.resnet_encoder.4.0              | Ba

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…