In [25]:
import argparse
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils import data
import torch
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state 

from utils import get_accuracy
from lars import LARSWrapper
import cifar_100
from logger import log_metrics as logger
import torch_trainer as trainer
from cifar_resnet import resnet18
import cifar_100

# Python helper imports
import os
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
import wandb
import pytorch_lightning as pl
import time



# Keyword arguments that `prepare_data` unpacks into
# `cifar_100.CIFAR_100_transformations` for the training dataset.
dataset_args = dict(
    crop_size=32,
    brightness=0.4,
    contrast=0.4,
    saturation=0.2,
    hue=0.1,
    color_jitter_prob=0.8,
    gray_scale_prob=0.2,
    horizontal_flip_prob=0.5,
    gaussian_prob=0.5,
    min_scale=0.16,
    max_scale=0.9,
)

# Validation variant: same keys as above, with the colour-jitter, grayscale
# and gaussian probabilities zeroed and the scale range narrowed to [0.9, 1].
val_dataset_args = {
    **dataset_args,
    'color_jitter_prob': 0,
    'gray_scale_prob': 0,
    'gaussian_prob': 0,
    'min_scale': 0.9,
    'max_scale': 1,
}



def prepare_data(dataset_args, val_dataset_args, batch_size=2, num_workers=2):
    """Build the training dataset and the train/validation data loaders.

    Args:
        dataset_args: keyword arguments unpacked into
            ``cifar_100.CIFAR_100_transformations`` for the training dataset.
        val_dataset_args: keyword arguments unpacked into the same class for
            the validation dataset.
        batch_size: samples per batch for both loaders. Defaults to 2, the
            value that used to be hard-coded.
        num_workers: worker processes per loader. Defaults to 2, the value
            that used to be hard-coded.

    Returns:
        A ``(train_dataset, dataloader, val_dataloader)`` tuple.
    """
    train_dataset = cifar_100.CIFAR_100_transformations(train=True, **dataset_args)
    dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=False,
        drop_last=True,
        # DataLoader rejects persistent workers when there are no workers.
        persistent_workers=num_workers > 0,
    )

    # Validation has to come from the held-out split; the previous
    # `train=True` evaluated on the data the model was trained on.
    # NOTE(review): assumes `CIFAR_100_transformations` follows torchvision's
    # `train` flag semantics (False -> test split) — confirm in cifar_100.py.
    val_dataset = cifar_100.CIFAR_100_transformations(train=False, **val_dataset_args)
    val_dataloader = DataLoader(
        val_dataset,
        # Fixed order keeps evaluation runs comparable with each other.
        shuffle=False,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=False,
        drop_last=True,
        persistent_workers=False,
    )
    return train_dataset, dataloader, val_dataloader


def cross_entropy_loss(logits, labels, num_classes=100):
    """Mean softmax cross-entropy between logits and integer class labels.

    Args:
        logits: unnormalised class scores with the class axis last.
        labels: integer class indices; they are one-hot encoded here.
        num_classes: width of the one-hot encoding. Defaults to 100, the
            value that used to be hard-coded, so existing callers are
            unaffected.

    Returns:
        The loss averaged over all examples.
    """
    labels_onehot = jax.nn.one_hot(labels, num_classes=num_classes)
    per_example = optax.softmax_cross_entropy(logits=logits, labels=labels_onehot)
    return per_example.mean()

def compute_metrics(logits, labels):
    """Summarise a batch as a dict holding its mean loss and accuracy.

    Accuracy is the fraction of examples whose arg-max over the last logits
    axis equals the label.
    """
    predicted = jnp.argmax(logits, -1)
    return {
        'loss': cross_entropy_loss(logits=logits, labels=labels),
        'accuracy': jnp.mean(predicted == labels),
    }

def create_train_state(rng, learning_rate, momentum):
    """Creates initial `TrainState`.

    Args:
        rng: JAX PRNG key used to initialise the model parameters.
        learning_rate: learning rate passed to ``optax.sgd``.
        momentum: momentum passed to ``optax.sgd``.

    Returns:
        A ``train_state.TrainState`` holding the model's ``apply`` function,
        its freshly initialised ``params`` and the SGD optimiser.
    """
    model = resnet18()
    # Dummy input, used only so `init` can trace the network.
    dummy_batch = jnp.ones((4, 32, 32, 3))  # (N, H, W, C) format
    # Initialise from the caller's key. A fixed PRNGKey(0) used to be
    # hard-coded here, which silently ignored the `rng` argument.
    variables = model.init(rng, dummy_batch)
    # Only the 'params' collection is kept; any other collection returned by
    # `init` is dropped.
    # NOTE(review): train_step marks 'batch_stats' as mutable, so the model
    # presumably has such a collection that is not tracked here — confirm.
    params = variables['params']
    tx = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

@jax.jit
def train_step(state, batch):
    """Run one optimisation step and report the batch metrics.

    Args:
        state: the current ``TrainState``; its ``apply_fn`` runs the forward
            pass and its ``params`` are differentiated.
        batch: mapping with an ``'image0'`` input array and a ``'label'``
            array of class indices.

    Returns:
        ``(new_state, metrics)``: the state after applying the gradients, and
        the dict produced by ``compute_metrics`` for this batch.
    """

    def loss_fn(params):
        # Use the apply function stored in the state instead of building a
        # second `resnet18()`, so the step always runs the model the state
        # was created from.
        # NOTE(review): the updated 'batch_stats' returned by the mutable
        # apply are discarded, so they are not carried between steps —
        # confirm this is intended.
        logits, _ = state.apply_fn(
            {'params': params}, batch['image0'], mutable=['batch_stats']
        )
        loss = cross_entropy_loss(logits=logits, labels=batch['label'])
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits=logits, labels=batch['label'])

    return new_state, metrics

In [26]:
# Inspect the concrete class of the training state.
# NOTE(review): `state` is only assigned in the following cell, so this cell
# relies on leftover kernel state and fails on a fresh top-to-bottom run.
type(state)

flax.training.train_state.TrainState

In [27]:
# Seed a root key and split off a dedicated key for parameter initialisation.
rng, init_rng = jax.random.split(jax.random.PRNGKey(0))
learning_rate, momentum = 0.1, 0.9
state = create_train_state(init_rng, learning_rate, momentum)
del init_rng  # Must not be used anymore.


# Smoke test: push only the first batch through a single training step.
# NOTE(review): the loop variable `data` rebinds the name imported by
# `from torch.utils import data` at the top of the notebook.
for i, data in enumerate(dataloader):
    # Move axis 1 to the end, matching the (N, H, W, C) layout used when the
    # model is initialised (presumably the loader yields N, C, H, W — confirm).
    nhwc_images = data['image0'].permute(0, 2, 3, 1)
    data['image0'] = jnp.array(nhwc_images)
    data['label'] = jnp.array(data['label'])
    state, metrics = train_step(state, data)
    break

In [13]:
# NOTE(review): `out` is not assigned in any visible cell, so this raises
# NameError — looks like a leftover debugging cell.
out

NameError: name 'out' is not defined

In [2]:
# Build the training dataset and both loaders from the argument dicts above.
# NOTE(review): `dataloader` is consumed by the training-loop cell that appears
# earlier in the notebook, so this cell must run first on a fresh kernel.
train_dataset, dataloader, val_dataloader = prepare_data(dataset_args, val_dataset_args)

Files already downloaded and verified
Files already downloaded and verified
