In [1]:
import jax, jax.numpy as jp
import flax
import flax.linen as nn

from models.vqgan import VQGAN

In [2]:
from scripts.common import TrainState
from config import VQConfig, AutoencoderConfig

enc_config = AutoencoderConfig(out_channels=256,
                               channel_multipliers=(1,2,4))
dec_config = AutoencoderConfig(out_channels=3,
                               channel_multipliers=(1,2,4))
vq_config = VQConfig(codebook_size=512)

gan = VQGAN(enc_config, dec_config, vq_config)

In [3]:
rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, (1, 256, 256, 3))
variables = jax.jit(gan.init, static_argnames=['train'])({'params': rng, 'dropout': rng}, x, train=True)

In [4]:
import optax

tx = optax.adam(1e-3)

state = TrainState.create(gan, params=variables.pop('params'), tx=tx, extra_variables=variables)

In [9]:
def train_step(state: TrainState, batch, rng):
    def loss_fn(params):
        x_recon, q_loss, result = state(batch, train=True, params=params, rngs={'dropout': rng})
        loss = optax.l2_loss(batch, x_recon).mean()
        result['recon_loss'] = loss
        result['q_loss'] = q_loss
        return loss, result
    
    state, info = state.apply_loss_fn(loss_fn, has_aux=True)
    return state, info

In [10]:
jit_train_step = jax.jit(train_step)

for _ in range(10):
    rng, subrng = jax.random.split(rng)
    batch = jax.random.normal(rng, (1, 256, 256, 3))
    state, info = jit_train_step(state, x, subrng)
    print(info)

2024-03-21 12:56:20.159418: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 8.17GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2024-03-21 12:56:22.295753: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 8.13GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


{'codebook_loss': Array(0.18131661, dtype=float32), 'commit_loss': Array(0.04532915, dtype=float32), 'encodings': Array([[[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..

In [11]:
jax.tree_map(jp.shape, info)

{'codebook_loss': (),
 'commit_loss': (),
 'encodings': (1, 64, 64, 512),
 'entropy_loss': (),
 'indices': (1, 64, 64),
 'q_loss': (),
 'recon_loss': ()}

In [15]:
from flax.training import checkpoints
import os, shutil

path = os.path.abspath('./checkpoints')
if os.path.exists(path):
    shutil.rmtree(path)
print(path)
checkpoints.save_checkpoint(path, target=state, step=state.step, keep=1, overwrite=True)

/home/bluesun/PycharmProjects/Repr_Learning/MaskGit/maskgit_jax2/models/mvtm/checkpoints


'/home/bluesun/PycharmProjects/Repr_Learning/MaskGit/maskgit_jax2/models/mvtm/checkpoints/checkpoint_11'

In [16]:
re_state = checkpoints.restore_checkpoint(path, state)



In [19]:
type(re_state)

scripts.common.TrainState

In [21]:
jax.tree_map(jp.shape, re_state.params)

{'conv': {'bias': (256,), 'kernel': (1, 1, 256, 256)},
 'decoder': {'Attention_0': {'Conv_0': {'kernel': (1, 1, 256, 768)},
   'Conv_1': {'kernel': (1, 1, 256, 256)},
   'GroupNorm_0': {'bias': (256,), 'scale': (256,)}},
  'ConvIn': {'bias': (256,), 'kernel': (3, 3, 256, 256)},
  'ConvOut': {'kernel': (3, 3, 64, 3)},
  'GroupNorm_0': {'bias': (64,), 'scale': (64,)},
  'ResBlock_0': {'Conv_0': {'kernel': (3, 3, 256, 256)},
   'Conv_1': {'kernel': (1, 1, 256, 256)},
   'GroupNorm_0': {'bias': (256,), 'scale': (256,)},
   'GroupNorm_1': {'bias': (256,), 'scale': (256,)}},
  'ResBlock_1': {'Conv_0': {'kernel': (3, 3, 256, 256)},
   'Conv_1': {'kernel': (1, 1, 256, 256)},
   'GroupNorm_0': {'bias': (256,), 'scale': (256,)},
   'GroupNorm_1': {'bias': (256,), 'scale': (256,)}},
  'UpBlock_0': {'ResBlock_0': {'Conv_0': {'kernel': (3, 3, 256, 256)},
    'Conv_1': {'kernel': (1, 1, 256, 256)},
    'GroupNorm_0': {'bias': (256,), 'scale': (256,)},
    'GroupNorm_1': {'bias': (256,), 'scale': (25