In [1]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [3]:
import torch
from diffusers import StableDiffusion3Pipeline

ModuleNotFoundError: No module named 'diffusers'

In [28]:
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)
# from https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_3
model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
pipe = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="balanced")

cuda


Loading pipeline components...:   0%|          | 0/9 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the cpu.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [23]:
print(torch.cuda.memory_allocated() / 1e6, "MB allocated")
print(torch.cuda.memory_reserved() / 1e6, "MB reserved")

23821.2096 MB allocated
24192.745472 MB reserved


In [25]:
torch.cuda.empty_cache()

In [32]:
transformer = pipe.transformer
vae = pipe.vae
scheduler = pipe.scheduler

In [34]:
print("DiT:", transformer.config)
print("VAE:", vae.config)
print("Scheduler:", scheduler)

DiT: FrozenDict([('sample_size', 128), ('patch_size', 2), ('in_channels', 16), ('num_layers', 24), ('attention_head_dim', 64), ('num_attention_heads', 24), ('joint_attention_dim', 4096), ('caption_projection_dim', 1536), ('pooled_projection_dim', 2048), ('out_channels', 16), ('pos_embed_max_size', 192), ('dual_attention_layers', ()), ('qk_norm', None), ('_use_default_values', ['dual_attention_layers', 'qk_norm']), ('_class_name', 'SD3Transformer2DModel'), ('_diffusers_version', '0.29.0.dev0'), ('_name_or_path', '/home/irisx/.cache/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671/transformer')])
VAE: FrozenDict([('in_channels', 3), ('out_channels', 3), ('down_block_types', ['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D']), ('up_block_types', ['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D']), ('block_out_channels', [128, 256, 512, 512]), ('lay

In [35]:
text_encoder = CanineModel.from_pretrained("google/canine-c", torch_dtype=torch.float16)
text_tokenizer = AutoTokenizer.from_pretrained("google/canine-c")

config.json:   0%|          | 0.00/698 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/528M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/892 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/657 [00:00<?, ?B/s]

In [37]:
print("Text Encoder:", text_encoder.config)
print("Text Tokenizer:", text_tokenizer)

Text Encoder: CanineConfig {
  "_attn_implementation_autoset": true,
  "architectures": [
    "CanineModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 57344,
  "downsampling_rate": 4,
  "eos_token_id": 57345,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "local_transformer_stride": 128,
  "max_position_embeddings": 16384,
  "model_type": "canine",
  "num_attention_heads": 12,
  "num_hash_buckets": 16384,
  "num_hash_functions": 8,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "torch_dtype": "float16",
  "transformers_version": "4.51.3",
  "type_vocab_size": 16,
  "upsampling_kernel_size": 4,
  "use_cache": true
}

Text Tokenizer: CanineTokenizer(name_or_path='google/canine-c', vocab_size=1114112, model_max_length=2048, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '\ue000', 'eos_token': '\ue001', 'sep_to

In [39]:
import torch.nn as nn

In [41]:
class Conditioner(nn.Module):
    def __init__(self, text_dim=768, style_dim=256, hidden_dim=1024):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(text_dim + style_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim)  # Project to DiT's cross-attn dimension
        )
    
    def forward(self, text_embeds, style_embeds):
        combined = torch.cat([text_embeds, style_embeds], dim=-1)
        return self.proj(combined)  # [batch, seq_len, hidden_dim]

conditioner = Conditioner()#.to("cuda").half()

In [42]:
print(conditioner)

Conditioner(
  (proj): Sequential(
    (0): Linear(in_features=1024, out_features=1024, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=1024, out_features=1024, bias=True)
  )
)


In [51]:
import torch
import torch.nn as nn
from torchvision.models import mobilenet_v3_small

class StyleEncoder(nn.Module):
    def __init__(self, style_embed_dim=256):
        super().__init__()
        self.base = mobilenet_v3_small(pretrained=False)  # Match DiffusionPen's backbone
        self.base.classifier = nn.Sequential(
            nn.Linear(576, style_embed_dim),
            nn.GELU()
        )
    
    def forward(self, image):
        return self.base(image)

style_encoder = StyleEncoder()
pretrained_weights = torch.load("DiffusionPen/style_models/iam_style_diffusionpen.pth", map_location="cuda:0")
style_encoder.load_state_dict(pretrained_weights.keys())

# Move to GPU and FP16
style_encoder = style_encoder.half()#.to("cuda").half()

odict_keys(['model.conv_stem.weight', 'model.bn1.weight', 'model.bn1.bias', 'model.bn1.running_mean', 'model.bn1.running_var', 'model.bn1.num_batches_tracked', 'model.blocks.0.0.conv_dw.weight', 'model.blocks.0.0.bn1.weight', 'model.blocks.0.0.bn1.bias', 'model.blocks.0.0.bn1.running_mean', 'model.blocks.0.0.bn1.running_var', 'model.blocks.0.0.bn1.num_batches_tracked', 'model.blocks.0.0.conv_pw.weight', 'model.blocks.0.0.bn2.weight', 'model.blocks.0.0.bn2.bias', 'model.blocks.0.0.bn2.running_mean', 'model.blocks.0.0.bn2.running_var', 'model.blocks.0.0.bn2.num_batches_tracked', 'model.blocks.1.0.conv_pw.weight', 'model.blocks.1.0.bn1.weight', 'model.blocks.1.0.bn1.bias', 'model.blocks.1.0.bn1.running_mean', 'model.blocks.1.0.bn1.running_var', 'model.blocks.1.0.bn1.num_batches_tracked', 'model.blocks.1.0.conv_dw.weight', 'model.blocks.1.0.bn2.weight', 'model.blocks.1.0.bn2.bias', 'model.blocks.1.0.bn2.running_mean', 'model.blocks.1.0.bn2.running_var', 'model.blocks.1.0.bn2.num_batches_tr

TypeError: Expected state_dict to be dict-like, got <class 'odict_keys'>.

In [None]:
from datasets import load_dataset

class HandwritingDataset(torch.utils.data.Dataset):
    def __init__(self, split="train"):
        self.data = load_dataset("your_dataset_name", split=split)
        
    def __getitem__(self, idx):
        item = self.data[idx]
        return {
            "image": item["image"],  # Preprocess to [3, H, W]
            "text": item["text"],
            "writer_id": item["writer_id"] #style
        }
    
    def __len__(self):
        return len(self.data)

# Create dataloader
dataset = HandwritingDataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)

In [None]:
# Optimizers
optimizer = torch.optim.AdamW(
    list(style_encoder.parameters()) + list(conditioner.parameters()),
    lr=1e-4
)

# Training
for epoch in range(100):
    for batch in dataloader:
        # Move data to GPU
        images = batch["image"].to("cuda", dtype=torch.float16)
        texts = batch["text"]
        writer_ids = batch["writer_id"]
        
        # Encode text
        text_inputs = text_tokenizer(texts, return_tensors="pt", padding=True)
        text_embeds = text_encoder(**text_inputs).last_hidden_state  # [batch, seq_len, 768]
        
        # Encode style
        style_embeds = style_encoder(images)  # [batch, 256]
        
        # Combine with conditioner
        cond_embeds = conditioner(text_embeds, style_embeds.unsqueeze(1))  # [batch, seq_len, 1024]
        
        # VAE Encode
        latents = vae.encode(images).latent_dist.sample() * 0.18215
        
        # Sample noise
        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, scheduler.num_train_timesteps, (latents.shape[0],))
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)
        
        # Predict noise (DiT forward)
        noise_pred = transformer(
            noisy_latents,
            timesteps,
            encoder_hidden_states=cond_embeds  # Our fused conditioning!
        ).sample
        
        # Loss
        loss = nn.functional.mse_loss(noise_pred, noise)
        
        # Backward
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
    print(f"Epoch {epoch}, Loss: {loss.item()}")

In [None]:
def generate_handwriting(text, style_image, num_inference_steps=20):
    # Encode style
    style_emb = style_encoder(style_image.unsqueeze(0).to("cuda").half())
    
    # Encode text
    text_inputs = text_tokenizer([text], return_tensors="pt")
    text_emb = text_encoder(**text_inputs.to("cuda")).last_hidden_state
    
    # Fuse conditioning
    cond_emb = conditioner(text_emb, style_emb.unsqueeze(1))
    
    # Sample latents
    latents = torch.randn((1, 4, 32, 32), device="cuda", dtype=torch.float16)
    
    # Denoise
    scheduler.set_timesteps(num_inference_steps)
    for t in scheduler.timesteps:
        noise_pred = transformer(
            latents,
            t,
            encoder_hidden_states=cond_emb
        ).sample
        
        latents = scheduler.step(noise_pred, t, latents).prev_sample
    
    # Decode
    image = vae.decode(latents / 0.18215).sample
    return image.detach().cpu().permute(0, 2, 3, 1).numpy()[0]