In [16]:
import pandas as pd
import matplotlib.pyplot as plt

import numpy as np

import jax
import jax.numpy as jnp
from jax import grad, jit

from jax.tree_util import tree_map
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score

import optax

from flax import linen as nn
from flax.training import train_state
from flax.serialization import (
    to_state_dict, msgpack_serialize, from_bytes
)

import os
import wandb
from typing import Callable
from tqdm.notebook import tqdm

In [None]:
%%writefile train_zl.sh
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"

In [None]:
!export TEST_NAME='tyup'
!echo $TEST_NAME

In [None]:
# Load the e-commerce CSV from the Kaggle input directory and preview its first rows.
ecom_data = pd.read_csv('/kaggle/input/jax-datasets/ecommerce_data.csv')
ecom_data.head()

In [None]:
# Discard the Email, Avatar and Address columns, then report the resulting shape.
ecom_data = ecom_data.drop(columns=['Email', 'Avatar', 'Address'])
ecom_data.shape

In [None]:
# Separate the regression target from the remaining feature columns.
y = ecom_data['Yearly Amount Spent']
X = ecom_data.drop(columns='Yearly Amount Spent')

In [None]:
# Hold out 20% of the rows for testing. The split is seeded so that
# Restart & Run All reproduces the same partition (and the same test score).
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
X_train.shape, X_test.shape, y_train.shape, y_test.shape

In [None]:
# Standardise the features. The scaler is fitted on the training split only and
# the same statistics are reused for the test split; re-fitting on the test
# rows would leak test-set information and put the two splits on different scales.
scaler = StandardScaler()

X_train_scaled = pd.DataFrame(scaler.fit_transform(X_train), columns=X_train.columns)
X_test_scaled = pd.DataFrame(scaler.transform(X_test), columns=X_test.columns)

In [None]:
# Summary statistics of the scaled training features.
X_train_scaled.describe()

In [None]:
# Convert the scaled features and the targets into float32 JAX arrays,
# rebinding the split names to the converted arrays.
X_train = jnp.array(X_train_scaled.to_numpy(), dtype=jnp.float32)
X_test = jnp.array(X_test_scaled.to_numpy(), dtype=jnp.float32)
y_train = jnp.array(y_train.to_numpy(), dtype=jnp.float32)
y_test = jnp.array(y_test.to_numpy(), dtype=jnp.float32)
X_train.shape, X_test.shape, y_train.shape, y_test.shape

In [None]:
# Initial weights: zeros shaped like X_train with its first axis dropped.
W = jnp.zeros(X_train.shape[1:])

# Initial bias.
b = 0.

# Step size multiplying the gradient in each update.
lr = 0.01

# Number of iterations of the training loop.
n_iter = 500

In [None]:
def predict_y(W, b, X):
    """Return the linear prediction ``dot(X, W) + b``."""
    weighted_sum = jnp.dot(X, W)
    return weighted_sum + b

In [None]:
def loss_fn(W, b, X, y):
    """Mean squared error between ``predict_y(W, b, X)`` and ``y``."""
    residual = predict_y(W, b, X) - y
    return jnp.square(residual).mean()

In [None]:
def update_W(W, b, X, y, lr):
    """Return ``W`` after one gradient-descent step on ``loss_fn``.

    Args:
        W: Weights (an array or a pytree of arrays) passed to ``loss_fn``.
        b: Bias passed to ``loss_fn``; not modified here.
        X: Input features.
        y: Targets.
        lr: Step size multiplying the gradient.

    Returns:
        A structure like ``W`` with every leaf replaced by ``leaf - lr * grad``.
    """
    grad_W = grad(loss_fn, argnums=0)(W, b, X, y)

    # The lambda now uses the leaves it is handed. Its second parameter used to
    # be misspelled (`graad_W`), so the body silently read `grad_W` from the
    # enclosing scope instead. `lr` is closed over rather than mapped as a tree
    # so that `W` may be any pytree, not just a single array.
    return tree_map(lambda w, g: w - lr * g, W, grad_W)
    
def update_b(W, b, X, y, lr):
    """Return ``b`` after one gradient-descent step on ``loss_fn``.

    Args:
        W: Weights passed to ``loss_fn``; not modified here.
        b: Bias (a scalar, array, or pytree) passed to ``loss_fn``.
        X: Input features.
        y: Targets.
        lr: Step size multiplying the gradient.

    Returns:
        A structure like ``b`` with every leaf replaced by ``leaf - lr * grad``.
    """
    grad_b = grad(loss_fn, argnums=1)(W, b, X, y)

    # The lambda now uses the leaves it is handed. Its second parameter used to
    # be misspelled (`graad_b`), so the body silently read `grad_b` from the
    # enclosing scope instead. `lr` is closed over rather than mapped as a tree.
    return tree_map(lambda bias, g: bias - lr * g, b, grad_b)

In [None]:
# Loss recorded at the start of every iteration.
loss_hist = []

# Wrap the update steps with jit once, outside the loop, instead of
# re-wrapping them on every iteration.
update_W_jit = jit(update_W)
update_b_jit = jit(update_b)

for i in range(n_iter):
    loss = loss_fn(W, b, X_train, y_train)

    if (i + 1) % 100 == 0:
        print ('Iteration', i+1, 'Loss:', loss)

    loss_hist.append(loss)

    # W is updated first; the bias step then sees the freshly updated W.
    W = update_W_jit(W, b, X_train, y_train, lr)
    b = update_b_jit(W, b, X_train, y_train, lr)

In [None]:
# Plot the recorded training loss against the iteration index.
_, ax = plt.subplots(figsize=(12, 8))
ax.set(title='Training Loss per Epoch', xlabel='Iteration', ylabel='Loss')
ax.plot(loss_hist)

In [None]:
# Predict the targets for the test features with the current W and b.
y_pred = predict_y(W, b, X_test)

In [None]:
# Report the R^2 score of the test predictions against the true targets.
print('Test Score: ', r2_score(y_test, y_pred))

In [None]:
def ll_distance(x, y):
    """Sum of absolute element-wise differences between two 1-D arrays.

    Raises:
        AssertionError: if either argument is not one-dimensional.
    """
    assert x.ndim == y.ndim == 1
    difference = x - y
    return jnp.abs(difference).sum()

In [None]:
# Draw a (100, 3) array of normal samples from a fixed PRNG key.
xs = jax.random.normal(jax.random.PRNGKey(0), (100, 3))

In [None]:
# Show the sampled points.
print(xs)

In [None]:
def pairwise_distances(dist, xs):
    """Apply ``dist`` to every ordered pair of rows of ``xs`` via nested ``vmap``.

    Entry ``[i][j]`` of the result is ``dist(xs[j], xs[i])``.
    """
    over_first_arg = jax.vmap(dist, (0, None))
    over_both_args = jax.vmap(over_first_arg, (None, 0))
    return over_both_args(xs, xs)

# Pairwise distance matrix of the sampled points; print its first row.
ys = pairwise_distances(ll_distance, xs)
print(ys[0])

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

In [None]:
# Fraction of the training set reserved for validation, and the batch size
# used when the datasets are batched.
validation_split = 0.2
batch_size = 64

# Load the CIFAR-10 train and test splits as (image, label) pairs, together
# with the dataset metadata.
(full_train_set, test_dataset), ds_info = tfds.load(
    'cifar10',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)


# The same normalisation is applied to the validation and test
# datasets in the next cell.

In [None]:
def normalize_img(image, label):
    """Cast ``image`` to float32 and divide it by 255; ``label`` passes through."""
    as_float = tf.cast(image, tf.float32)
    return as_float / 255., label

# Normalise every training image before splitting.
full_train_set = full_train_set.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE
)

num_data = tf.data.experimental.cardinality(
    full_train_set
).numpy()
print("Total number of data points:", num_data)

# Split sizes are computed as plain integers, and the two splits are disjoint:
# the validation set is what remains after skipping the training examples.
# (Taking the first `validation_split` fraction again would make every
# validation example a training example as well.)
# NOTE(review): this assumes the dataset yields the same order on every pass;
# with shuffle_files=True and a multi-shard dataset that should be confirmed.
num_val = int(num_data * validation_split)
num_train = int(num_data) - num_val
train_dataset = full_train_set.take(num_train)
val_dataset = full_train_set.skip(num_train)
print(
    "Number of train data points:",
    tf.data.experimental.cardinality(train_dataset).numpy()
)
print(
    "Number of val data points:",
    tf.data.experimental.cardinality(val_dataset).numpy()
)

# Cache, shuffle with a buffer as large as the split, then batch.
train_dataset = train_dataset.cache()
train_dataset = train_dataset.shuffle(
    tf.data.experimental.cardinality(train_dataset).numpy()
)
train_dataset = train_dataset.batch(batch_size)

val_dataset = val_dataset.cache()
val_dataset = val_dataset.shuffle(
    tf.data.experimental.cardinality(val_dataset).numpy()
)
val_dataset = val_dataset.batch(batch_size)


# The test split gets the same normalisation but is not shuffled.
test_dataset = test_dataset.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE
)
print(
    "Number of test data points:",
    tf.data.experimental.cardinality(test_dataset).numpy()
    )
test_dataset = test_dataset.cache()
test_dataset = test_dataset.batch(batch_size)

In [None]:
# Iterate the training batches as NumPy arrays and show the labels
# (second element of the pair) of the first batch.
train_datagen = iter(tfds.as_numpy(train_dataset))
next(train_datagen)[1]

In [None]:
# Dataset metadata returned by tfds.load.
ds_info

# WandB Tutorial

In [23]:
# Seed for the PRNG key, the name of the pooling variant to use, and the batch size.
# NOTE: this rebinds `batch_size`, which was set to 64 for the tf.data pipeline earlier.
seed = 42
pooling = "avg"
batch_size = 4

# Maps a pooling name to the Flax pooling function handed to the CNN.
MODULE_DICT = {
    "avg": nn.avg_pool,
    "max": nn.max_pool,
}

In [24]:
class CNN(nn.Module):
    """Convolutional network: three conv-conv-pool stages followed by three dense layers.

    Attributes:
        pool_module: Pooling function applied after each pair of convolutions;
            it is called with ``window_shape=(2, 2)`` and ``strides=(2, 2)``.
    """
    pool_module: Callable = nn.avg_pool

    def setup(self):
        """Create the submodules; the attribute names become the parameter keys."""
        self.conv_1 = nn.Conv(features=32, kernel_size=(3, 3))
        self.conv_2 = nn.Conv(features=32, kernel_size=(3, 3))
        self.conv_3 = nn.Conv(features=64, kernel_size=(3, 3))
        self.conv_4 = nn.Conv(features=64, kernel_size=(3, 3))
        self.conv_5 = nn.Conv(features=128, kernel_size=(3, 3))
        self.conv_6 = nn.Conv(features=128, kernel_size=(3, 3))
        self.dense_1 = nn.Dense(features=1024)
        self.dense_2 = nn.Dense(features=512)
        self.dense_output = nn.Dense(features=10)

    # No @nn.compact here: every submodule is created in setup() and nothing is
    # defined inline, so the decorator was redundant. Flax documents setup()
    # and @nn.compact as alternative ways to declare submodules.
    def __call__(self, x):
        """Run the forward pass and return 10 outputs per input example.

        Args:
            x: Input batch whose first axis is the batch axis; the notebook
                feeds arrays of shape ``(batch, 32, 32, 3)``.
        """
        x = nn.relu(self.conv_1(x))
        x = nn.relu(self.conv_2(x))
        x = self.pool_module(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.relu(self.conv_3(x))
        x = nn.relu(self.conv_4(x))
        x = self.pool_module(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.relu(self.conv_5(x))
        x = nn.relu(self.conv_6(x))
        x = self.pool_module(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten everything except the batch axis before the dense layers.
        x = x.reshape((x.shape[0], -1))
        x = nn.relu(self.dense_1(x))
        x = nn.relu(self.dense_2(x))
        return self.dense_output(x)

#     @nn.compact
#     def __call__(self, x):
#         x = nn.relu(nn.Conv(features=32, kernel_size=(3, 3))(x))
#         x = nn.relu(nn.Conv(features=32, kernel_size=(3, 3))(x))
#         x = self.pool_module(x, window_shape=(2, 2), strides=(2, 2))
#         x = nn.relu(nn.Conv(features=64, kernel_size=(3, 3))(x))
#         x = nn.relu(nn.Conv(features=64, kernel_size=(3, 3))(x))
#         x = self.pool_module(x, window_shape=(2, 2), strides=(2, 2))
#         x = nn.relu(nn.Conv(features=128, kernel_size=(3, 3))(x))
#         x = nn.relu(nn.Conv(features=128, kernel_size=(3, 3))(x))
#         x = self.pool_module(x, window_shape=(2, 2), strides=(2, 2))
#         x = x.reshape((x.shape[0], -1))
#         x = nn.relu(nn.Dense(features=1024)(x))
#         x = nn.relu(nn.Dense(features=512)(x))
#         return nn.Dense(features=10)(x)


In [25]:
rng = jax.random.PRNGKey(seed) # PRNG Key
x = jnp.ones(shape=(batch_size, 32, 32, 3)) # Dummy Input
model = CNN(pool_module=MODULE_DICT[pooling]) # Instantiate the Model
params = model.init(rng, x) # Initialize the parameters
# `jax.tree_map` is deprecated (and removed in recent JAX releases); use the
# `tree_map` already imported from `jax.tree_util` at the top of the notebook.
tree_map(lambda leaf: leaf.shape, params) # Check the parameters

2023-11-28 09:43:09.463798: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2461] Execution of replica 0 failed: INTERNAL: Failed to load in-memory CUBIN (compiled for a different GPU?).: CUDA_ERROR_OUT_OF_MEMORY: out of memory


In [None]:
# Inspect the type of the key view of the 'params' collection.
type(params['params'].keys())

In [None]:
# Cardinality of the batched training dataset, i.e. the number of batches.
tf.data.experimental.cardinality(train_dataset).numpy()

In [None]:
def init_train_state(
    model, random_key, shape, learning_rate
) -> train_state.TrainState:
    """Build a ``TrainState`` for ``model`` driven by an Adam optimizer.

    Args:
        model: Module providing ``init`` and ``apply``.
        random_key: PRNG key passed to ``model.init``.
        shape: Shape of the all-ones dummy input used for initialisation.
        learning_rate: Learning rate handed to ``optax.adam``.

    Returns:
        A ``TrainState`` holding ``model.apply``, the ``'params'`` collection
        produced by initialisation, and the optimizer.
    """
    dummy_input = jnp.ones(shape)
    initial_params = model.init(random_key, dummy_input)['params']
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=initial_params,
        tx=optax.adam(learning_rate),
    )


# Learning rate passed to the Adam optimizer.
learning_rate = 0.01

# Training state for `model`, initialised with a dummy batch of 32x32x3 inputs.
state = init_train_state(
    model, rng, (batch_size, 32, 32, 3), learning_rate
)

In [None]:
# Shape of the first image in the first training batch.
next(iter(tfds.as_numpy(train_dataset)))[0][0].shape

In [None]:
# model.apply(params, next(iter(tfds.as_numpy(train_dataset)))[0][0].reshape(1,32,32,3))
# Re-create the train state with a dummy batch of size 1.
# NOTE: the chained assignment also rebinds `state`, replacing the one created earlier.
new_state = state = init_train_state(
    model, rng, (1, 32, 32, 3), learning_rate
)

# Forward pass on the first training image, reshaped into a batch of one.
new_state.apply_fn({'params': new_state.params}, next(iter(tfds.as_numpy(train_dataset)))[0][0].reshape(1,32,32,3))

In [None]:
# Top-level keys of the initialised variables.
params.keys()

In [None]:
# Kernel shape of the first convolution layer.
params['params']['conv_1']['kernel'].shape

In [None]:
# Number of devices reported by JAX.
NUM_DEVICES = jax.device_count()
NUM_DEVICES

In [None]:
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding

In [None]:
# A 64x64 integer array (this rebinds `x`); visualise how it is laid out across devices.
x = jnp.arange(64 * 64).reshape(64, 64)
jax.debug.visualize_array_sharding(x)

In [None]:
# Inspect the sharding object attached to x.
x.sharding

In [None]:
# Build a positional sharding over the local devices and view it as a 2x1 grid.
# NOTE(review): reshape(2, 1) presumably requires exactly two local devices — confirm.
sharding = PositionalSharding(jax.local_devices())
sharding.reshape(2, 1)

In [None]:
# Place x using the 2x1 sharding replicated along axis 0, then visualise the result.
# NOTE: this rebinds `y`, which earlier held the regression target.
y = jax.device_put(x, sharding.reshape(2, 1).replicate(0))
jax.debug.visualize_array_sharding(y)

In [None]:
# Put a 64x64 view of x on the 2x1 device grid, reshape the placed array to
# (64, 8, 8, 1), and count its addressable shards.
t = x.reshape(64, 64)
k = jax.device_put(t, sharding.reshape(2, 1))
k = k.reshape(64, 8, 8, 1)
len(k.addressable_shards)

In [None]:
# Shape check: t viewed as (8, 8, 64, 1).
t.reshape(8, 8, 64, 1).shape

In [None]:
# Identity function wrapped in pmap, used to inspect the sharding of a pmap output.
@jax.pmap
def g(x):
    return x

# The input is reshaped so that its leading axis has size 2.
# NOTE(review): presumably this needs two devices — confirm.
p = g(t.reshape(2, 32, 64, 1))
p.sharding

In [None]:
# Draw the noise once and rescale it as (max - x) / (max - min). The original
# expression re-sampled from the same key four times, which yields the
# identical tensor each time, so the result is unchanged.
noise = jax.random.normal(jax.random.PRNGKey(1234), (1, 32, 32, 3))
new_ds = (noise.max() - noise) / (noise.max() - noise.min())
new_state.apply_fn({'params': new_state.params}, new_ds)

In [None]:
# Load the MNIST train and test splits as (image, label) pairs.
# NOTE: this rebinds `ds_info`, which previously held the CIFAR-10 metadata.
(train_ds, test_ds), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

In [None]:
def show_img(img, ax=None, title=None):
  """Draw channel 0 of ``img`` in grayscale on ``ax`` with the ticks removed.

  Falls back to the current Matplotlib axes when ``ax`` is None, and sets
  ``title`` on the axes only when it is truthy.
  """
  target_ax = plt.gca() if ax is None else ax
  first_channel = img[..., 0]
  target_ax.imshow(first_channel, cmap='gray')
  for hide_ticks in (target_ax.set_xticks, target_ax.set_yticks):
    hide_ticks([])
  if not title:
    return
  target_ax.set_title(title)

def show_img_grid(imgs, titles):
  """Shows a grid of images.

  The images are laid out row by row on the smallest square grid that holds
  them all; each image is drawn with ``show_img`` using the matching title.

  Args:
    imgs: Sequence of images accepted by ``show_img``.
    titles: Titles paired with ``imgs`` position by position.
  """
  if len(imgs) == 0:
    # Nothing to draw; a 0x0 grid cannot be created.
    return
  n = int(np.ceil(len(imgs)**.5))
  # squeeze=False keeps `axs` two-dimensional even for a 1x1 grid, where
  # plt.subplots would otherwise return a single Axes that cannot be indexed.
  _, axs = plt.subplots(n, n, figsize=(3 * n, 3 * n), squeeze=False)
  for i, (img, title) in enumerate(zip(imgs, titles)):
    show_img(img, axs[i // n][i % n], title)

In [None]:
# Pull 25 (image, label) pairs with a single iterator. The original called
# `next(train_ds.as_numpy_iterator())` for every entry, which builds a fresh
# iterator each time, and built the titles by indexing into the image (`[0]`)
# instead of reading the label (`[1]`).
samples = list(train_ds.take(25).as_numpy_iterator())
show_img_grid(
    [img for img, _ in samples],
    [f'label={label}' for _, label in samples],
)

In [1]:
!pip install -qq --upgrade transformers diffusers

In [2]:
from diffusers import StableDiffusionXLPipeline, AutoencoderKL
import torch

# Load the "madebyollin/sdxl-vae-fp16-fix" VAE with float16 weights.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

# Build the SDXL base pipeline around that VAE, also in float16, and move it
# to the "cuda" device (so this cell needs a CUDA-capable GPU).
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    vae=vae, torch_dtype=torch.float16,
).to("cuda")
# Load the LoRA weights identified by `style_lora` into the pipeline.
style_lora = "johnowhitaker/lora-sdxl-njstyle"
pipeline.load_lora_weights(style_lora)

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE']


Downloading config.json:   0%|          | 0.00/631 [00:00<?, ?B/s]

Downloading (…)ch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

Downloading model_index.json:   0%|          | 0.00/609 [00:00<?, ?B/s]

Fetching 17 files:   0%|          | 0/17 [00:00<?, ?it/s]

Downloading (…)cheduler_config.json:   0%|          | 0.00/479 [00:00<?, ?B/s]

Downloading tokenizer/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/737 [00:00<?, ?B/s]

Downloading (…)ncoder_2/config.json:   0%|          | 0.00/575 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/492M [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/2.78G [00:00<?, ?B/s]

Downloading (…)_encoder/config.json:   0%|          | 0.00/565 [00:00<?, ?B/s]

Downloading tokenizer/vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/460 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/725 [00:00<?, ?B/s]

Downloading unet/config.json:   0%|          | 0.00/1.68k [00:00<?, ?B/s]

Downloading (…)ch_model.safetensors:   0%|          | 0.00/10.3G [00:00<?, ?B/s]

Downloading (…)ch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

Downloading tokenizer_2/vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Downloading (…)_weights.safetensors:   0%|          | 0.00/372M [00:00<?, ?B/s]

In [3]:
# Fetch the state dict of a second LoRA, plus the second value returned with it (bound to `alphas`).
lora_sd, alphas = pipeline.lora_state_dict("johnowhitaker/lora-sdxl-plushie")


Downloading (…)_weights.safetensors:   0%|          | 0.00/372M [00:00<?, ?B/s]

In [4]:
# Walk the LoRA state dict and, on reaching its 1001st entry, print the key,
# the tensor, and whether a lookup by that key returns the very same object.
# `k` and `v` stay bound to that entry after the loop.
i = 1
for k, v in lora_sd.items():
    if i != 1001:
        i += 1
        continue
    print(k)
    print(v)
    print(lora_sd[k] is v)
    break

unet.unet.up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_k.lora.down.weight
tensor([[-0.0215,  0.0165, -0.0238,  ..., -0.0112, -0.0065, -0.0004],
        [-0.0090,  0.0174, -0.0243,  ...,  0.0274,  0.0008, -0.0087],
        [-0.0150,  0.0116, -0.0104,  ..., -0.0006,  0.0188,  0.0111],
        ...,
        [-0.0003, -0.0251, -0.0076,  ..., -0.0108, -0.0131,  0.0119],
        [ 0.0038, -0.0006,  0.0019,  ...,  0.0041,  0.0177, -0.0236],
        [-0.0035, -0.0103,  0.0343,  ...,  0.0261, -0.0160, -0.0171]])
True


In [5]:
# Part of the key before ".lora." (`k` is left over from the loop above).
k.split(".lora.")[0]

'unet.unet.up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_k'

In [6]:
# Derive an attribute path from the key left in `k`: keep the part before
# ".lora.", rewrite each ".<digit>" as "[<digit>]", and display the result
# without its first five characters.
target_layer = k.split(".lora.")[0]
# target_layer = target_layer.replace(f".0", f"[0]")
for i in range(10):
    target_layer = target_layer.replace(f".{i}", f"[{i}]")
target_layer[5:]

'unet.up_blocks[0].attentions[2].transformer_blocks[8].attn2.to_k'

In [7]:
# Shape of the ".lora.down.weight" tensor stored under this key.
lora_sd['unet.unet.up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_k.lora.down.weight'].numpy().shape

(64, 2048)

In [8]:
# Shape of the pipeline weight that `target_layer` points at.
# NOTE(review): eval executes a string assembled from a state-dict key; only safe for trusted checkpoints.
eval(f"pipeline.{target_layer[5:]}").weight.data.shape

torch.Size([1280, 2048])

In [9]:
# Peek at the third and fourth keys of the LoRA state dict.
list(lora_sd.keys())[2:4]

['unet.unet.down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.lora.down.weight',
 'unet.unet.down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.lora.up.weight']

In [10]:
# Weight tensor of one to_k layer, reached by explicit attribute access.
pipeline.unet.up_blocks[0].attentions[2].transformer_blocks[8].attn2.to_k.weight.data

tensor([[ 0.0013, -0.0101,  0.0103,  ..., -0.0008,  0.0061,  0.0002],
        [ 0.0096,  0.0074, -0.0222,  ..., -0.0037, -0.0011, -0.0003],
        [-0.0041,  0.0070, -0.0100,  ..., -0.0049,  0.0100,  0.0035],
        ...,
        [-0.0035,  0.0042,  0.0114,  ...,  0.0191,  0.0088,  0.0200],
        [-0.0049, -0.0021,  0.0037,  ..., -0.0096, -0.0137, -0.0017],
        [-0.0037,  0.0082, -0.0078,  ...,  0.0092,  0.0126, -0.0049]],
       device='cuda:0', dtype=torch.float16)

In [26]:
# Shape of the VAE's post_quant_conv weight.
pipeline.vae.post_quant_conv.weight.data.shape

torch.Size([4, 4, 1, 1])

In [12]:
# Running total accumulated by the loop at the end of this cell.
param_count = 0

def target_layer_from_sd_name(k):
    """Translate a LoRA state-dict key into an attribute path below ``pipeline``.

    Numeric path components become index expressions and the first five
    characters of the result (the leading ``'unet.'``) are dropped, e.g.
    ``'unet.unet.up_blocks.0.attn.to_k.lora.down.weight'`` becomes
    ``'unet.up_blocks[0].attn.to_k'``.

    Args:
        k: State-dict key, either of the ``...<layer>.lora.<...>`` form or of
            the ``...processor.to_<name>_lora...`` form.

    Returns:
        The attribute path as a string, suitable for appending to ``'pipeline.'``.
    """
    import re  # local import: only this helper needs it

    # They use slightly different naming schemes for attn processors vs the rest
    if '.processor.to_' in k:
        target_layer = k.split("processor.to_")[0] + k.split(".processor.")[1].split("_lora")[0]
        target_layer = target_layer.replace("to_out", "to_out[0]")
    else:
        target_layer = k.split(".lora.")[0]
    # Turn every '.<number>' into '[<number>]'. A regex handles indices of any
    # length; replacing one digit at a time turned '.10' into '[1]0'.
    target_layer = re.sub(r"\.(\d+)", r"[\1]", target_layer)
    # Return (skipping the first 'unet.' in this case):
    return target_layer[5:]

# For every LoRA key, resolve the matching pipeline layer and add the two
# dimensions of its weight to the running total, then print the total.
# NOTE(review): eval executes a string built from state-dict keys; only safe
# for trusted checkpoints.
for k, v in lora_sd.items():
    target_layer = target_layer_from_sd_name(k)
    aa, bb = eval(f'pipeline.{target_layer}').weight.data.shape
    param_count += aa+bb
    
print(f"{param_count:,d}")    

2,903,040


In [27]:
# Show the key currently bound to `k` (a loop variable left over from an earlier cell).
k

'unet.unet.up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_v.lora.up.weight'

In [14]:
# Display the UNet module of the pipeline.
pipeline.unet

UNet2DConditionModel(
  (conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): LoRACompatibleLinear(in_features=320, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
  )
  (add_time_proj): Timesteps()
  (add_embedding): TimestepEmbedding(
    (linear_1): LoRACompatibleLinear(in_features=2816, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): LoRACompatibleLinear(in_features=1280, out_features=1280, bias=True)
  )
  (down_blocks): ModuleList(
    (0): DownBlock2D(
      (resnets): ModuleList(
        (0-1): 2 x ResnetBlock2D(
          (norm1): GroupNorm(32, 320, eps=1e-05, affine=True)
          (conv1): LoRACompatibleConv(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): LoRACompatibleLinear(in_features=1280, out_features=320, bias=True)
          (

In [68]:
# NOTE: this bare expression is not the last line of the cell, so its value is not displayed.
list(lora_sd.items())[0][0]

# The strings '.0' through '.9'.
look_up_list = ['.'+str(i) for i in range(10)]

# base_params['unet']['up_blocks_0']['attentions_2']['transformer_blocks_8']['attn2']['to_k']
# unet.unet.down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k

def target_layer_from_sd_name(k):
    """Convert a LoRA state-dict key into a chain of dictionary subscripts.

    The part of ``k`` before ``'.lora.'`` loses its first five characters,
    every ``'.<digit>'`` is rewritten as ``'_<digit>'``, and each remaining
    dot-separated component is wrapped as ``['component']``.

    NOTE: this re-uses the name of the helper defined earlier in the notebook
    and replaces it once this cell has run.
    """
    path = k.split(".lora.")[0][5:]
    for digit in range(10):
        path = path.replace(f".{digit}", f"_{digit}")
    return ''.join(f"['{part}']" for part in path.split('.'))


# Try the conversion on the first key of the LoRA state dict.
target_layer_from_sd_name(list(lora_sd.items())[0][0])
# param_count = 0

# def target_layer_from_sd_name(k):
#     # They use slightly different naming schemes for attn processors vs the rest
#     if '.processor.to_' in k:
#         target_layer = k.split("processor.to_")[0] + k.split(".processor.")[1].split("_lora")[0]
#         target_layer = target_layer.replace("to_out", "to_out[0]")
#     else:
#         target_layer = k.split(".lora.")[0]
#     # Replace '.1.' with '[1]' and so on:
#     for i in range(10):target_layer = target_layer.replace(f".{i}", f"[{i}]")
#     # Return (skipping the first 'unet.' in this case):
#     return target_layer[5:]

# for k, v in lora_sd.items():
#     target_layer = target_layer_from_sd_name(k)
#     aa, bb = eval(f'pipeline.{target_layer}').weight.data.shape
#     param_count += aa+bb

"['unet']['down_blocks_1']['attentions_0']['transformer_blocks_0']['attn1']['to_k']"

In [None]:
# NOTE(review): as written this indexes the list literal ['unet'] with a string, which raises
# TypeError; presumably a parameter dict was meant to precede the subscripts — confirm.
['unet']['up_blocks_0']['attentions_2']['transformer_blocks_8']['attn2']['to_k']

In [58]:
# Check that the fifth dot-separated component of the first key (before ".lora.") parses as an
# integer: int() raises ValueError if it does not, otherwise the comparison is True.
type(int(list(lora_sd.items())[0][0].split(".lora.")[0].split('.')[4])) == int

['.0', '.1', '.2', '.3', '.4', '.5', '.6', '.7', '.8', '.9']