In [None]:
!pip install imagen-pytorch
!pip install -U sentence-transformers
!pip install einops-exts

In [None]:
import glob
import os

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import Compose
from torchvision.transforms import ToTensor


class AnimeFaceDataset(Dataset):
    """Pairs of (image tensor, text prompt) read from a folder and a labels CSV.

    Every file in ``path_to_dataset`` is one sample. Its prompt is looked up by
    file name in the CSV, which must have an ``image`` column (file name) and a
    ``prompt`` column (caption).

    Args:
        path_to_dataset: directory containing the images. A trailing path
            separator is optional.
        path_to_faces_labels: path of the labels CSV.
        transform: optional callable applied to the image tensor produced by
            ``ToTensor``, or ``None`` for no extra transform.
    """

    def __init__(self, path_to_dataset, path_to_faces_labels, transform):
        self.imgs_path = path_to_dataset
        # os.path.join accepts the directory with or without a trailing
        # separator. Sorting makes index -> file deterministic, because glob
        # returns files in filesystem order.
        self.images_names = sorted(glob.glob(os.path.join(self.imgs_path, "*")))
        # file name -> prompt, for constant-time lookup in __getitem__.
        labels_df = pd.read_csv(path_to_faces_labels)
        self.anime_faces_labels = labels_df.set_index('image')['prompt'].to_dict()
        self.transform = transform

    def __len__(self):
        """Number of image files found in the dataset directory."""
        return len(self.images_names)

    def __getitem__(self, idx):
        """Return ``(img_tensor, prompt)`` for the ``idx``-th image.

        Raises:
            KeyError: if the image's file name has no row in the labels CSV.
        """
        img_path = self.images_names[idx]
        # basename is separator-agnostic, unlike splitting on "/".
        img_name = os.path.basename(img_path)
        try:
            img_labels = self.anime_faces_labels[img_name]
        except KeyError:
            raise KeyError(f"no prompt for image {img_name!r} in labels CSV") from None
        # The context manager closes the file handle. Otherwise, long training
        # runs over many images keep descriptors open until garbage collection.
        with Image.open(img_path) as image:
            img_tensor = ToTensor()(image.convert("RGB"))
        if self.transform is not None:
            img_tensor = self.transform(img_tensor)
        return img_tensor, img_labels

In [None]:
from einops import rearrange

import torch
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt

def get_prompts_embeddings(prompts, encoder, tokenizer, max_length=256):
    """Encode text prompts into per-token embeddings for Imagen conditioning.

    Prompts are tokenized with padding to the longest prompt in the batch
    (truncated at ``max_length`` tokens). They are then run through
    ``encoder`` without gradients, and positions that are padding according
    to the attention mask are set to 0.

    Args:
        prompts: list of prompt strings.
        encoder: text encoder module (e.g. ``T5EncoderModel``). It is called
            with ``input_ids``/``attention_mask`` and returns an object with
            ``last_hidden_state``.
        tokenizer: tokenizer matching ``encoder``.
        max_length: maximum number of tokens kept per prompt.

    Returns:
        ``last_hidden_state`` with padded positions zeroed, on the encoder's
        device. The mask broadcasting below implies shape
        ``(batch, seq_len, hidden_dim)``.
    """
    encoded = tokenizer.batch_encode_plus(
        prompts,
        return_tensors="pt",
        padding='longest',
        max_length=max_length,
        truncation=True,
    )

    # The tokenizer always returns CPU tensors. Move them to wherever the
    # encoder's weights live, so that a GPU-resident encoder also works.
    first_param = next(iter(encoder.parameters()), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    input_ids = encoded.input_ids.to(device)
    attn_mask = encoded.attention_mask.to(device).bool()

    with torch.no_grad():
        output = encoder(input_ids=input_ids, attention_mask=attn_mask.long())
    encoded_text = output.last_hidden_state

    # (batch, seq) mask -> (batch, seq, 1) so it broadcasts over the hidden dim.
    return encoded_text.masked_fill(~attn_mask.unsqueeze(-1), 0.)

def train(dataloader, imagen, encoder, tokenizer, trainer, unet_to_train,
          imagen_test=None, max_batch_size=8):
    """Train one unet of an Imagen model for a single pass over ``dataloader``.

    For each ``(images, prompts)`` batch:
    - the prompts are embedded with ``get_prompts_embeddings``;
    - ``trainer`` computes the loss for unet ``unet_to_train``;
    - ``trainer.update`` applies the optimizer step.

    The step index and loss are printed after each step.

    Args:
        dataloader: iterable of ``(images, prompts)`` batches, e.g. a
            ``DataLoader`` over ``AnimeFaceDataset``.
        imagen: not used inside the loop; kept for call compatibility.
        encoder: text encoder passed to ``get_prompts_embeddings``.
        tokenizer: tokenizer matching ``encoder``.
        trainer: ``ImagenTrainer`` wrapping the Imagen being trained.
        unet_to_train: unet number forwarded as ``unet_number`` (the first
            unet is number 1).
        imagen_test: not used. It is optional so that the notebook's
            six-argument call works.
        max_batch_size: forwarded to the trainer. Presumably the trainer splits
            larger batches into chunks of this size and accumulates gradients
            (confirm against the imagen-pytorch docs).
    """
    for step, (images, prompts) in enumerate(dataloader):
        images = images.cuda()
        # Default collate yields the prompts as a sequence of strings; the
        # tokenizer wants a list.
        prompts_embeddings = get_prompts_embeddings(list(prompts), encoder, tokenizer).cuda()

        loss_unet = trainer(
            images,
            text_embeds=prompts_embeddings,
            unet_number=unet_to_train,
            max_batch_size=max_batch_size,
        )
        trainer.update(unet_number=unet_to_train)

        print(f"step {step}: loss {loss_unet}")

In [None]:
import torch
import torchvision

from transformers import T5Tokenizer, T5EncoderModel

from imagen_pytorch import Unet, SRUnet256, Imagen, ImagenTrainer
from torch.utils.data import DataLoader

# Base unet (the only stage): generates 64x64 images conditioned on text embeddings.
unet1 = Unet(
    dim = 256,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
    layer_cross_attns = (False, True, True, True)
)


# Imagen wrapping the single base unet above (no super-resolution unets).
# text_encoder_name must name the same T5 checkpoint that produces the training
# embeddings (model_name = 'google/t5-v1_1-large' in the encoder cell).
# Otherwise imagen.sample(texts=...) encodes prompts with a different model
# than the one the unet was conditioned on during training.
imagen = Imagen(
    unets = unet1,
    text_encoder_name = 'google/t5-v1_1-large',
    image_sizes = 64,
    timesteps = 1000,
    cond_drop_prob = 0.1
).cuda()

trainer = ImagenTrainer(imagen).cuda()

In [None]:
# T5 checkpoint whose encoder embeds the prompts fed to get_prompts_embeddings.
# Keep it identical to Imagen's `text_encoder_name`. That way, training
# (text_embeds from this encoder) and sampling from raw texts use the same
# text representation.
model_name = 'google/t5-v1_1-large'
tokenizer, encoder = (
    T5Tokenizer.from_pretrained(model_name, model_max_length=256),
    T5EncoderModel.from_pretrained(model_name),
)

In [None]:
size_of_batch = 8
IMAGES_DIR = "../input/another-anime-face-dataset/animefaces256cleaner/"
LABELS_CSV = "../input/anime-faces-labels-cleaned/anime_faces_labels_cleaned.csv"

# Downscale to the base unet's resolution (image_sizes = 64). The resize runs
# on tensors, so antialias is requested explicitly for PIL-like quality.
train_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((64, 64), antialias=True),
])

anime_face_dataset = AnimeFaceDataset(IMAGES_DIR, LABELS_CSV, train_transform)
# Shuffle for training. Otherwise every pass sees the same fixed file order,
# which correlates consecutive batches.
dataloader = DataLoader(anime_face_dataset, batch_size=size_of_batch, shuffle=True)

In [None]:
# `imagen_test` is passed explicitly because train() declares it as a parameter.
train(dataloader, imagen, encoder, tokenizer, trainer, unet_to_train=1, imagen_test=None)

In [None]:
sample_prompts = [
    'anime girl with long red hair, blue eyes, smug face',
    'anime girl with short blue hair, red eyes'
]

# Condition on embeddings from the same encoder/tokenizer used in train().
# Sampling then sees exactly the text representation the unet was trained on.
sample_embeds = get_prompts_embeddings(sample_prompts, encoder, tokenizer).cuda()

images = imagen.sample(
    text_embeds = sample_embeds,
    start_at_unet_number = 1,
    cond_scale = 2.
)

# One (channels, height, width) image per prompt. Clamp to [0, 1] so imshow
# does not warn about or clip out-of-range floats.
n_samples = len(sample_prompts)
fig, axes = plt.subplots(1, n_samples, figsize=(4 * n_samples, 4), squeeze=False)
for ax, img, prompt in zip(axes[0], images, sample_prompts):
    ax.imshow(img.permute(1, 2, 0).clamp(0, 1).detach().cpu().numpy())
    ax.set_title(prompt, fontsize=8, wrap=True)
    ax.axis('off')
plt.show()