In [1]:
# %pip install matplotlib

In [2]:
import pandas as pd

splits = {'train': 'data/train-00000-of-00001-e4094e9912b63f1e.parquet', 'test': 'data/test-00000-of-00001-f13470702cfdb4c7.parquet'}
df = pd.read_parquet("hf://datasets/Babypotatotang/lld-onlyicon/" + splits["train"])

In [3]:
# Text has a given structure, let's separate it into columns
# e.g. "A logo of sincere brand Propstack in Real Estate industry" into "sincere", "Propstack", "Real Estate"
def extract_columns(text):
    parts = text.split(" ")
    assert len(parts) >= 9, f"Unexpected text format (too short): {text}"
    assert parts[0] == "A" and parts[1] == "logo" and parts[2] == "of", f"Unexpected text format: {text}"
    assert parts[-1] == "industry", f"Unexpected text format: {text}"
    assert "in" in parts and "brand" in parts, f"Unexpected text format: {text}"
    adjective = parts[3]
    in_index = parts.index("in")
    industry = " ".join(parts[in_index + 1:-1])
    brand_index = parts.index("brand")
    brand = " ".join(parts[brand_index + 1:in_index])
    return pd.Series([adjective, brand, industry])

# Try it on a few samples
for i in range(5):
    print(extract_columns(df.iloc[i]["text"]))

0        sincere
1      Propstack
2    Real Estate
dtype: object
0    sophisticated
1        Propstack
2      Real Estate
dtype: object
0              sincere
1              Cadcorp
2    Computer Software
dtype: object
0                               sincere
1                                Oceana
2    Non-profit Organization Management
dtype: object
0                    sincere
1                    Amplifr
2    Marketing & Advertising
dtype: object


In [4]:
# Apply to the whole dataframe
df[["adjective", "brand", "industry"]] = df["text"].apply(extract_columns)
df.head()

Unnamed: 0,image,text,adjective,brand,industry
0,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,A logo of sincere brand Propstack in Real Esta...,sincere,Propstack,Real Estate
1,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,A logo of sophisticated brand Propstack in Rea...,sophisticated,Propstack,Real Estate
2,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,A logo of sincere brand Cadcorp in Computer So...,sincere,Cadcorp,Computer Software
3,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,A logo of sincere brand Oceana in Non-profit O...,sincere,Oceana,Non-profit Organization Management
4,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,A logo of sincere brand Amplifr in Marketing &...,sincere,Amplifr,Marketing & Advertising


In [5]:
# See values of each new column
print("Adjectives:", df["adjective"].unique(), f"({len(df['adjective'].unique())} unique)")
print("Brands:", df["brand"].unique(), f"({len(df['brand'].unique())} unique)")
print("Industries:", df["industry"].unique(), f"({len(df['industry'].unique())} unique)")

Adjectives: ['sincere' 'sophisticated' 'exciting' 'rugged'] (4 unique)
Brands: ['Propstack' 'Cadcorp' 'Oceana' ... 'PHI' 'DesignHammer' 'IndieBox'] (9070 unique)
Industries: ['Real Estate' 'Computer Software' 'Non-profit Organization Management'
 'Marketing & Advertising' 'Internet' 'Apparel & Fashion' 'Machinery'
 'Construction' 'Information Technology & Services' 'Retail'
 'Consumer Goods' 'Computer Hardware' 'Computer Games'
 'Environmental Services' 'Gambling & Casinos' 'Automotive' 'Oil & Energy'
 'Civic & Social Organization' 'Airlines/Aviation' 'Utilities'
 'E-learning' 'Health, Wellness & Fitness' 'Food & Beverages' 'Cosmetics'
 'Broadcast Media' 'Online Media' 'Electrical & Electronic Manufacturing'
 'Financial Services' 'Hospital & Health Care' 'Pharmaceuticals'
 'Restaurants' 'Leisure, Travel & Tourism' 'Photography' 'Research'
 'Publishing' 'Information Services' 'Public Relations & Communications'
 'Consumer Electronics' 'Insurance' 'Media Production'
 'Architecture & Plan

In [6]:
# CONSTANTS
BATCH_SIZE = 32
IMG_SIZE = 64
IN_CHANNELS = 3
BASE_CHANNELS = 64
TIME_EMB_DIM = 128
COND_EMB_DIM = 64
LR = 1e-4
T = 500     # diffusion steps
EPOCHS = 200
CFG_DROP_CHANCE = 0.1  # chance to drop conditioning for classifier-free guidance

In [7]:
import torch
from torch.utils.data import DataLoader
from dataset import LogoDataset
    
adjectives = sorted(df["adjective"].unique())
industries = sorted(df["industry"].unique())

num_adjectives = len(adjectives)
num_industries = len(industries)

adjective_to_idx = {adj: i for i, adj in enumerate(adjectives)}
industry_to_idx = {ind: i for i, ind in enumerate(industries)}

dataset = LogoDataset(df, adjective_to_idx, industry_to_idx, img_size=64)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=False)

In [8]:
from model import UNet, cosine_beta_schedule, q_sample

device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
model = UNet(
    num_adjectives=num_adjectives + 1,  # +1 for unconditioned case
    num_industries=num_industries + 1,  # +1 for unconditioned case
    in_channels=IN_CHANNELS,
    base_channels=BASE_CHANNELS,
    time_emb_dim=TIME_EMB_DIM,
    cond_emb_dim=COND_EMB_DIM
).to(device)

print("Number of trainable parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, betas=(0.9, 0.999), weight_decay=1e-4)
scaler = torch.amp.GradScaler('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
num_training_steps = EPOCHS * len(dataloader)
print("Total training steps:", num_training_steps)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_training_steps)
max_grad_norm = 1.0

betas = cosine_beta_schedule(T)
alphas = 1.0 - betas
alphas_bar = torch.cumprod(alphas, dim=0).to(device)

Number of trainable parameters: 1698979
Total training steps: 93600


In [9]:
from tqdm import tqdm

for epoch in range(EPOCHS):
    model.train()
    loop = tqdm(dataloader, desc=f"Epoch [{epoch+1}/{EPOCHS}]")
    for batch in loop:
        imgs, adjective_ids, industry_ids = batch
        imgs = imgs.to(device)
        adjective_ids = adjective_ids.to(device)
        industry_ids = industry_ids.to(device)
        batch_dim = imgs.shape[0]

        optimizer.zero_grad()

        # classifier-free guidance dropout
        drop_mask = (torch.rand(batch_dim, device=device) < CFG_DROP_CHANCE)
        adjective_ids[drop_mask] = -1  # unconditioned
        industry_ids[drop_mask] = -1  # unconditioned
        t = torch.randint(0, T, (batch_dim,), device=device).long()
        epsilon = torch.randn_like(imgs)

        with torch.amp.autocast('cuda'):
            x_t = q_sample(imgs, t, epsilon, alphas_bar)
            pred_noise = model(x_t, t, adjective_ids, industry_ids)
            loss = torch.nn.functional.mse_loss(pred_noise, epsilon)

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        loop.set_postfix(loss=loss.item(), lr=scheduler.get_last_lr()[0])

Epoch [1/200]:   0%|          | 0/468 [00:16<?, ?it/s]


TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.