# RotNet for Semi-supervised Learning on Cifar10 with 10% Labeled Data

- [Unsupervised Representation Learning by Predicting Image Rotations](https://arxiv.org/abs/1803.07728)

### Default Hyper-Parameters
- Resnet18-preact
- Cosine Annealing Learning Rate with initial value = 1e-3*2
- Weight Decay on AdamW = 1e-4
- Training Epochs = 300
- Mixed Precision Training + XLA (overall, speeds up training by more than 2x)
- Batchsize = 128*2 (due to mixed precision)

### Settings

#### Semi-Supervised

1. pretrain on the whole Cifar10 dataset using the 4 rotation labels
2. with a few low-level layers fixed, fine-tune with 10% of the labels from Cifar10

#### Unsupervised

1. pretrain on the whole Cifar10 dataset using the 4 rotation labels
2. take the whole pretrained network, append one linear FC layer to it, then fine-tune with the whole Cifar10 dataset

### Results Benchmark

|Exp_Name|Accuracy|Description|
|:-:|:-:|:-:|
|Supervised_FullLabel|0.95000|Upper Bound|
|PretrainRotNet|0.92250|Pretrained Model|
|**Semi_on_PretrainRotNet_1**|**0.85380**|**Best, Fixed 1 Conv**|
|Semi_on_PretrainRotNet_2|0.85010|Fixed 5 Conv|
|Semi_on_PretrainRotNet_3|0.84310|Fixed 9 Conv|
|Semi_on_PretrainRotNet_4|0.81930|Fixed 13 Conv|
|**Supervised_LowLabel**|**0.80290**|**Baseline**|
|Unsupervised_on_PretrainRotNet|0.65610|Fixed 17 Conv, wd=0.0|
|Unsupervised_on_PretrainRotNet|0.64600|Fixed 17 Conv, wd=1e-5|
|Unsupervised_on_PretrainRotNet|0.59100|Fixed 17 Conv|
|Semi_on_PretrainRotNet_5|0.54780|Fixed 17 Conv|

In [None]:
def f():
    """Pretrain RotNet on the rotated Cifar10 dataset (experiment 'PretrainRotNet').

    Meant to be used as a ``multiprocessing.Process`` target, so all imports
    and all flag configuration happen inside the function body.

    After training, the value returned by ``runner.get_saved_models_path()``
    is printed so it can be copied into the fine-tuning cells, which need the
    pretrained checkpoint path.
    """
    import sys
    # Make packages two directories up importable (presumably where tf_rlib lives).
    sys.path.append('../..')
    import tf_rlib
    from tf_rlib.runners.research import RotNetRunner
    FLAGS = tf_rlib.FLAGS

    # tf_rlib.utils.purge_logs()
    tf_rlib.utils.set_gpus('0')
    tf_rlib.utils.set_logging('WARN')

    FLAGS.exp_name = 'PretrainRotNet'

    FLAGS.xla = True
    FLAGS.amp = True

    FLAGS.bs = 128
    FLAGS.dim = 2
    FLAGS.out_dim = 4  # one output per rotation label
    FLAGS.lr = 1e-3
    FLAGS.wd = 1e-4
    FLAGS.epochs = 300

    if FLAGS.amp:
        # With mixed precision the batch size is doubled, and the learning
        # rate is doubled along with it.
        FLAGS.lr = FLAGS.lr * 2
        FLAGS.bs = FLAGS.bs * 2

    datasets = tf_rlib.datasets.Cifar10Rotate().get_data()
    runner = RotNetRunner(*datasets)
    runner.fit(FLAGS.epochs, lr=FLAGS.lr)

    # Previously this value was computed and silently dropped; report it so
    # the checkpoint location does not have to be looked up by hand.
    paths = runner.get_saved_models_path()
    print(paths)

from multiprocessing import Process
# Run the experiment in a separate process and block until it has finished.
# NOTE(review): presumably done so that GPU/framework state is released when
# the child process exits -- confirm.
p = Process(target=f, args=())
p.start()
p.join()

## Semi-supervised on Pretrained RotNet 

- please remember to modify the pretrained path

In [None]:
def exp_fixed_num_layers(
        k,
        pretrained_path='/results/PretrainRotNet/20200424-092318/ckpt/best/pretrained_resnet18'):
    """Fine-tune a pretrained RotNet on 10% of the Cifar10 labels.

    Meant to be used as a ``multiprocessing.Process`` target, so all imports
    and all flag configuration happen inside the function body.

    Args:
        k: Forwarded as the second argument of ``runner.load_front_layers``;
            also used as the suffix of the experiment name. According to the
            results table in this notebook, k = 1..5 corresponds to
            1, 5, 9, 13 and 17 fixed conv layers.
        pretrained_path: Checkpoint of the pretrained network produced by the
            'PretrainRotNet' experiment. The default is the path that used to
            be hard-coded here; pass your own instead of editing the source.
    """
    import sys
    # Make packages two directories up importable (presumably where tf_rlib lives).
    sys.path.append('../..')
    import tf_rlib
    from tf_rlib.runners.research import SemiRotNetRunner
    FLAGS = tf_rlib.FLAGS

    # tf_rlib.utils.purge_logs()
    tf_rlib.utils.set_gpus('0')
    tf_rlib.utils.set_logging('WARN')

    FLAGS.exp_name = 'Semi_on_PretrainRotNet_{}'.format(k)

    FLAGS.xla = True
    FLAGS.amp = True

    FLAGS.bs = 128
    FLAGS.dim = 2
    FLAGS.out_dim = 10
    FLAGS.lr = 1e-3
    FLAGS.wd = 1e-4
    FLAGS.epochs = 300
    if FLAGS.amp:
        # With mixed precision the batch size is doubled, and the learning
        # rate is doubled along with it.
        FLAGS.lr = FLAGS.lr * 2
        FLAGS.bs = FLAGS.bs * 2

    datasets = tf_rlib.datasets.Cifar10Semi(0.1).get_data()
    runner = SemiRotNetRunner(*datasets)
    runner.load_front_layers(pretrained_path, k)
    # Show which layers are trainable after loading the pretrained front layers.
    for lay in runner.model.layers:
        print(lay.name, lay.trainable)
    runner.fit(FLAGS.epochs, lr=FLAGS.lr)
    
from multiprocessing import Process
# Run the configurations k = 1..5 one after another, each in its own process.
for i in range(1, 6):
    p = Process(target=exp_fixed_num_layers, args=(i,))
    p.start()
    # Wait for this run to finish before starting the next one.
    p.join()

## Supervised on Low-label Regime

In [None]:
def f():
    """Train the 'Supervised_LowLabel' baseline on 10% of the Cifar10 labels.

    No pretrained weights are loaded. Meant to be used as a
    ``multiprocessing.Process`` target, so all imports and all flag
    configuration happen inside the function body.
    """
    import sys
    # Make packages two directories up importable (presumably where tf_rlib lives).
    sys.path.append('../..')
    import tf_rlib
    from tf_rlib.runners.research import RotNetRunner, SemiRotNetRunner

    flags = tf_rlib.FLAGS
    utils = tf_rlib.utils

    # utils.purge_logs()
    utils.set_gpus('0')
    utils.set_logging('WARN')

    # Experiment configuration, applied in the listed order.
    for flag_name, flag_value in (
            ('exp_name', 'Supervised_LowLabel'),
            ('xla', True),
            ('amp', True),
            ('bs', 128),
            ('dim', 2),
            ('out_dim', 10),
            ('lr', 1e-3),
            ('wd', 1e-4),
            ('epochs', 300),
    ):
        setattr(flags, flag_name, flag_value)

    if flags.amp:
        # Mixed precision: double both the learning rate and the batch size.
        flags.lr *= 2
        flags.bs *= 2

    data = tf_rlib.datasets.Cifar10Semi(0.1).get_data()
    trainer = SemiRotNetRunner(*data)
    trainer.fit(flags.epochs, lr=flags.lr)

from multiprocessing import Process
# Run the experiment in a separate process and block until it has finished.
# NOTE(review): presumably done so that GPU/framework state is released when
# the child process exits -- confirm.
p = Process(target=f, args=())
p.start()
p.join()

## Unsupervised is validated on Full-label Regime (add additional one linear layer)

- please remember to modify the pretrained path

In [None]:
def f():
    """Evaluate the pretrained RotNet features on the full-label Cifar10 data.

    Runs the 'Unsupervised_on_PretrainRotNet' experiment: the pretrained
    checkpoint is loaded via ``load_front_layers(..., 5)`` and the model is
    then trained on the dataset returned by ``Cifar10().get_data()``.

    Meant to be used as a ``multiprocessing.Process`` target, so all imports
    and all flag configuration happen inside the function body.
    """
    import sys
    # Make packages two directories up importable (presumably where tf_rlib lives).
    sys.path.append('../..')
    import tf_rlib
    from tf_rlib.runners.research import RotNetRunner, SemiRotNetRunner

    # Checkpoints written by the 'PretrainRotNet' run; edit to match your own run.
    paths = {
        'pretrained_resnet18': '/results/PretrainRotNet/20200424-092318/ckpt/best/pretrained_resnet18',
        'tail': '/results/PretrainRotNet/20200424-092318/ckpt/best/tail',
    }

    flags = tf_rlib.FLAGS
    utils = tf_rlib.utils

    # utils.purge_logs()
    utils.set_gpus('0')
    utils.set_logging('WARN')

    # Experiment configuration, applied in the listed order.
    for flag_name, flag_value in (
            ('exp_name', 'Unsupervised_on_PretrainRotNet'),
            ('xla', True),
            ('amp', True),
            ('bs', 128),
            ('dim', 2),
            ('out_dim', 10),
            ('lr', 1e-3),
            ('wd', 0.0),  # 0.0 -> too few trainable parameters
            ('epochs', 300),
    ):
        setattr(flags, flag_name, flag_value)

    if flags.amp:
        # Mixed precision: double both the learning rate and the batch size.
        flags.lr *= 2
        flags.bs *= 2

    data = tf_rlib.datasets.Cifar10().get_data()
    trainer = SemiRotNetRunner(*data)
    trainer.load_front_layers(paths['pretrained_resnet18'], 5)
    # Show which layers are trainable after loading the pretrained front layers.
    for layer in trainer.model.layers:
        print(layer.name, layer.trainable)
    trainer.fit(flags.epochs, lr=flags.lr)

from multiprocessing import Process
# Run the experiment in a separate process and block until it has finished.
# NOTE(review): presumably done so that GPU/framework state is released when
# the child process exits -- confirm.
p = Process(target=f, args=())
p.start()
p.join()

## Supervised on Full-label Regime

In [None]:
def f():
    """Train the 'Supervised_FullLabel' reference model on the Cifar10 dataset.

    No pretrained weights are loaded; the results table in this notebook lists
    this experiment as the upper bound. Meant to be used as a
    ``multiprocessing.Process`` target, so all imports and all flag
    configuration happen inside the function body.
    """
    import sys
    # Make packages two directories up importable (presumably where tf_rlib lives).
    sys.path.append('../..')
    import tf_rlib
    from tf_rlib.runners.research import RotNetRunner, SemiRotNetRunner

    flags = tf_rlib.FLAGS
    utils = tf_rlib.utils

    # utils.purge_logs()
    utils.set_gpus('0')
    utils.set_logging('WARN')

    # Experiment configuration, applied in the listed order.
    for flag_name, flag_value in (
            ('exp_name', 'Supervised_FullLabel'),
            ('xla', True),
            ('amp', True),
            ('bs', 128),
            ('dim', 2),
            ('out_dim', 10),
            ('lr', 1e-3),
            ('wd', 1e-4),
            ('epochs', 300),
    ):
        setattr(flags, flag_name, flag_value)

    if flags.amp:
        # Mixed precision: double both the learning rate and the batch size.
        flags.lr *= 2
        flags.bs *= 2

    data = tf_rlib.datasets.Cifar10().get_data()
    trainer = SemiRotNetRunner(*data)
    trainer.fit(flags.epochs, lr=flags.lr)

from multiprocessing import Process
# Run the experiment in a separate process and block until it has finished.
# NOTE(review): presumably done so that GPU/framework state is released when
# the child process exits -- confirm.
p = Process(target=f, args=())
p.start()
p.join()