Train on custom input -> noise as output #315

Closed
Felix-FN opened this issue Mar 28, 2024 · 9 comments

Felix-FN commented Mar 28, 2024

Hey there,
I'm fairly new to machine learning in general, but I need to train a text-to-image (Txt2Img) model on a custom dataset.
I came across this implementation and it seemed like a relatively "small effort" to get the model working... I'm not so sure about that anymore.

I took the setup script from the README and adapted it to load my own dataset.
Are there any clues on how to make this work, or could you provide a functional training script and inference script?
Any help is highly appreciated!

import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP
from torchvision.transforms import ToPILImage, ToTensor
from torchvision import transforms
from PIL import Image
import pickle
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

num_epochs = 10
batch_size = 2
device = torch.device("cuda")

pic_idx = 0
clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 4500,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8,
    use_all_token_embeds = True,            # whether to use fine-grained contrastive learning (FILIP)
    decoupled_contrastive_learning = True,  # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
    extra_latent_projection = True,         # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
    use_visual_ssl = True,                  # whether to do self supervised learning on images
    visual_ssl_type = 'simclr',             # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
    use_mlm = False,                        # use masked language learning (MLM) on text (DeCLIP)
    text_ssl_loss_weight = 0.05,            # weight for text MLM loss
    image_ssl_loss_weight = 0.05            # weight for image self-supervised learning loss
).cuda()
 
""" Get my own Dataset into the Model """
def xosc2ImageDataset(): 
    """Dataset is local and looks like: 
        dataDic = {
            "train":{
                "ImgPath": [], "caption": []}, # First Caption belongs to first image and so on
            "num_rows":{},
            'validation':{
                "ImgPath": [], "caption": []}}"""
    with open('saved_dictionary.pkl', 'rb') as f:
        loaded_dict = pickle.load(f)
    dset = loaded_dict
    return dset

class TextDataset:
    def __init__(self, texts, batch_size=4, max_length=4500):
        self.texts = texts
        self.batch_size = batch_size
        self.max_length = max_length
    def __len__(self):
        return len(self.texts)
    def __iter__(self):
        self.current_index = 0  # Reset the index at the start of each iteration
        return self
    def __next__(self):
        batch_texts = []
        for _ in range(self.batch_size):
            if self.current_index >= len(self.texts):
                raise StopIteration
            text = self.texts[self.current_index]
            text = text[:self.max_length]
            tensor = torch.tensor([ord(char) for char in text])
            batch_texts.append(tensor)
            self.current_index += 1
        padded_tensors = pad_sequence(batch_texts, batch_first=True)
        return padded_tensors

class ImageDataset:
    def __init__(self, image_paths, batch_size=4, image_size=(256, 256)):
        self.image_paths = image_paths
        self.batch_size = batch_size
        self.image_size = image_size
    def __len__(self):
        return len(self.image_paths)
    def __iter__(self):
        self.current_index = 0  # Reset the index at the start of each iteration
        return self
    def __next__(self):
        batch_images = []
        for _ in range(self.batch_size):
            if self.current_index >= len(self.image_paths):
                raise StopIteration
            path = self.image_paths[self.current_index]
            normalized_path = path.replace('\\', '/')
            image = self.load_image(normalized_path)
            batch_images.append(image)
            self.current_index += 1
        return torch.stack(batch_images, dim=0)
    def load_image(self, path):
        transform = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
        ])
        image = Image.open(path).convert("RGB")
        image = transform(image)
        return image

    # mock data
    #text = torch.randint(0, 49408, (4, 256)).cuda()
    #images = torch.randn(4, 3, 256, 256).cuda()

dset = xosc2ImageDataset()  # load the pickled dataset once instead of four times
text_list = dset["train"]['caption'] # list of captions
image_list = dset["train"]['image_url']
val_text_list = dset["validation"]['caption']
val_image_list = dset["validation"]['image_url']

text_dataset = TextDataset(text_list, batch_size=batch_size)
image_dataset = ImageDataset(image_list, batch_size=batch_size)

# Number of full batches per epoch (a trailing partial batch is dropped)
num_batches = len(text_dataset) // batch_size

clip_optimizer = torch.optim.Adam(clip.parameters(), lr=3e-4)  # learning rate is an assumption; tune it for your data
for epoch in range(num_epochs): # train CLIP
    text_loader = iter(text_dataset)
    image_loader = iter(image_dataset)
    for _ in range(num_batches):  # iterate over the number of batches in the dataset
        batch_texts = next(text_loader)
        batch_images = next(image_loader)
        loss = clip(
            batch_texts.to(device),
            batch_images.to(device),
            return_loss=True
        )
        print("Clip Epoch:", epoch, "Loss:", loss.item())
        clip_optimizer.zero_grad()
        loss.backward()
        clip_optimizer.step()  # update the model parameters; backward() alone never changes the weights

"""for epoch in range(num_epochs): # train
    for batch_texts, batch_images in zip(text_dataset, image_dataset):
        loss = clip(
            batch_texts, #text,
            batch_images, #images,
            return_loss = True
        )
        print("Clip Epoch:", epoch)
        loss.backward()"""
    # do above for many steps ...

"""prior networks (with transformer)"""  #setup prior network, which contains an autoregressive transformer
prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

diffusion_prior = DiffusionPrior(# diffusion prior network, which contains the CLIP and network (with transformer) above
    net = prior_network,
    clip = clip,
    timesteps = 1000,
    sample_timesteps = 64,
    cond_drop_prob = 0.2
).cuda()
"""for epoch in range(num_epochs):
    for batch_texts, batch_images in zip(text_dataset, image_dataset):
        loss = diffusion_prior(batch_texts,batch_images)#(text, images)
        print("Prior Epoch:", epoch)
        loss.backward()"""

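# optimize only the trainable parameters (the CLIP weights inside the prior are not meant to be trained here)
prior_optimizer = torch.optim.Adam([p for p in diffusion_prior.parameters() if p.requires_grad], lr=3e-4)
for epoch in range(num_epochs):
    text_loader = iter(text_dataset)  # re-create the iterators once per epoch
    image_loader = iter(image_dataset)
    for _ in range(num_batches):
        batch_texts = next(text_loader)
        batch_images = next(image_loader)
        loss = diffusion_prior(
            batch_texts.to(device),
            batch_images.to(device)
            )
        print("Prior Epoch:", epoch, "Loss:", loss.item())
        prior_optimizer.zero_grad()
        loss.backward()
        prior_optimizer.step()  # update the prior's parameters
    # do above for many steps ...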


""" decoder (with unet)"""
unet1 = Unet(
    dim = 128,
    image_embed_dim = 512,
    text_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8),
    cond_on_text_encodings = True    # set to True for any unets that need to be conditioned on text encodings
).cuda()
unet2 = Unet(
    dim = 16,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16)
).cuda()

decoder = Decoder(
    unet = (unet1, unet2),      # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
    image_sizes = (128, 256),   # resolutions: 128 for the first unet, 256 for the second. These must be unique and in ascending order (matching the unets passed in)
    clip = clip,
    timesteps = 100,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5
).cuda()
# feed images into decoder, specifying which unet you want to train
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
"""for epoch in range(num_epochs):
    for unet_number in (1, 2):
        for batch_texts, batch_images in zip(text_dataset, image_dataset):
            loss = decoder(batch_images, text = batch_texts, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
            print("Decoder Epoch:", epoch, "unet_number: ", unet_number)
            loss.backward()"""
decoder_optimizer = torch.optim.Adam([p for p in decoder.parameters() if p.requires_grad], lr=3e-4)
for epoch in range(num_epochs):
    text_loader = iter(text_dataset)
    image_loader = iter(image_dataset)
    for _ in range(num_batches):
        batch_texts = next(text_loader)
        batch_images = next(image_loader)
        for unet_number in (1, 2):
            loss = decoder(batch_images.to(device),
                           text=batch_texts.to(device),
                           unet_number=unet_number)
            print("Decoder Epoch:", epoch, "unet_number:", unet_number, "Loss:", loss.item())
            decoder_optimizer.zero_grad()
            loss.backward()
            decoder_optimizer.step()  # update the parameters of the unet that produced this loss

"""Validation"""
# Text- und Bild-Datensätze für die Validierung erstellen
val_text_dataset = TextDataset(val_text_list, batch_size=batch_size)
val_image_dataset = ImageDataset(val_image_list, batch_size=batch_size)
clip.eval()
diffusion_prior.eval()
decoder.eval()
# Verfolgen des Validierungsloss
val_clip_loss = 0.0
val_prior_loss = 0.0
val_decoder_loss = 0.0
# Iteration über die Validierungsdaten
with torch.no_grad():
    for batch_texts, batch_images in zip(val_text_dataset, val_image_dataset):
        # Clip Loss berechnen
        clip_loss = clip(batch_texts, batch_images, return_loss=True)
        val_clip_loss += clip_loss.item()

        # Prior Loss berechnen
        prior_loss = diffusion_prior(batch_texts, batch_images)
        val_prior_loss += prior_loss.item()

        # Decoder Loss berechnen
        decoder_loss = decoder(batch_images, text=batch_texts, unet_number=1)  
        val_decoder_loss += decoder_loss.item()

# Durchschnittliche Loss-Werte berechnen
num_batches = len(val_text_dataset)
val_clip_loss /= num_batches
val_prior_loss /= num_batches
val_decoder_loss /= num_batches

# Ausgabe der durchschnittlichen Loss-Werte
print("Validation Clip Loss:", val_clip_loss)
print("Validation Prior Loss:", val_prior_loss)
print("Validation Decoder Loss:", val_decoder_loss)

""" Generating Images """
dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
)

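# Encode the prompt the same way as the training captions (one token per character via ord()).
# Passing a plain string would make DALLE2 tokenize it with its built-in BPE tokenizer,
# which does not match the character tokens the models were trained on.
prompt = '<VERY LONG INPUT>'
prompt_tokens = torch.tensor([[ord(char) for char in prompt[:4500]]], device=device)
images = dalle2(prompt_tokens, cond_scale = 2.) # classifier-free guidance strength (> 1 strengthens the condition)
if images.dim() == 3:  # a single prompt may come back as one (C, H, W) image rather than a batch
    images = images.unsqueeze(0)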

for img in images:
    img = ToPILImage()(img)
    img.show()
    img.save("./dalle2_demo/20Epochs"+str(pic_idx)+".png", format="PNG")
    pic_idx += 1
    # save your image (in this example, of size 256x256)

Felix-FN commented Apr 2, 2024

Update:
Got it running, but I get pure noise as output. Any tips?
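When training runs but sampling returns pure noise, one thing worth ruling out is that the weights never actually change: loss.backward() only computes gradients, and the parameters are updated only when an optimizer's step() is called (with zero_grad() so gradients do not pile up across batches). A minimal sketch on a toy nn.Linear model (not the DALLE2 networks), assuming a recent PyTorch; train_steps is a hypothetical helper:

```python
import torch
from torch import nn

torch.manual_seed(0)


def train_steps(use_optimizer_step: bool, steps: int = 5) -> bool:
    """Run a few training steps on a toy model; return True if any weight changed."""
    model = nn.Linear(4, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    before = [p.detach().clone() for p in model.parameters()]
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    for _ in range(steps):
        loss = nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()   # clear gradients from the previous step
        loss.backward()         # only computes gradients
        if use_optimizer_step:
            optimizer.step()    # this is what actually updates the weights
    after = list(model.parameters())
    return any(not torch.equal(b, a) for b, a in zip(before, after))


print("backward only, weights changed:", train_steps(False))    # False
print("backward + step, weights changed:", train_steps(True))   # True
```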

Felix-FN changed the title from "Train on custom input" to "Train on custom input -> noise as output" on Apr 2, 2024
@xiaomei2002

Hey, did you manage to get correct output without noise? I still don't know how to run the code. Could you please share your training code? Thanks!!

@Felix-FN
Author

Hello, yes.

I'm still an amateur, but I got it working.

Attached is basically my whole model. There are definitely a few inconsistencies, but once I assembled DALLE2, it produces images as I envisioned them.

You can read about how to "assemble" DALLE2 here: #282

Load the trainer checkpoints as in the code I provided, then index the loaded checkpoint with ['model'], i.e. torch.load(checkpoint_path)['model'].

That's it.

Model.zip
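Per the comment above, the trainer checkpoints keep the network weights under a "model" key. A minimal sketch of that loading pattern in plain PyTorch, assuming the checkpoint is a dict whose "model" entry is a state dict; the toy network, make_model, and the extra "step" key are hypothetical, and an in-memory buffer stands in for checkpoint_path:

```python
import io

import torch
from torch import nn


def make_model() -> nn.Module:
    # Hypothetical stand-in for a unet or prior network
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))


trained = make_model()

# Assumed checkpoint layout: a dict whose "model" entry holds the state dict
buffer = io.BytesIO()
torch.save({"model": trained.state_dict(), "step": 1000}, buffer)
buffer.seek(0)

checkpoint = torch.load(buffer, map_location="cpu")
fresh = make_model()
fresh.load_state_dict(checkpoint["model"])  # index with ["model"] before loading

x = torch.randn(2, 8)
print(torch.allclose(trained(x), fresh(x)))  # True
```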

@xiaomei2002

Thank you very much! It helps me a lot!!
By the way, how many images does your dataset have? How many epochs does it take to get a good result?
I'm a complete newcomer and have no idea about any of this.
It would be great if you could answer me!! ^ ^

@Felix-FN
Author

Yeah, no problem.
I have nearly 200k images to train on.

Because of that size, it took nearly 24 h for 3 epochs on an A100 GPU.

It took maybe 5-8 epochs each to see quite good results at that low resolution.
If you want a better resolution, you may need to add a second UNet to upscale it.
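Taking those figures at face value (about 200k images, roughly 24 h for 3 epochs on one A100, with "nearly 24h" rounded to 24 h), the per-epoch time, throughput, and cost of 5-8 epochs work out as below. training_estimate is just an illustrative helper, and the result is a rough estimate for one model trained at that rate:

```python
def training_estimate(num_images: int, epochs_measured: int, hours_measured: float):
    """Derive hours per epoch and images per second from one measured run."""
    hours_per_epoch = hours_measured / epochs_measured
    images_per_second = num_images * epochs_measured / (hours_measured * 3600)
    return hours_per_epoch, images_per_second


hours_per_epoch, ips = training_estimate(200_000, 3, 24)
print(f"{hours_per_epoch:.1f} h/epoch, {ips:.1f} images/s")  # 8.0 h/epoch, 6.9 images/s
for epochs in (5, 8):
    print(f"{epochs} epochs ~ {epochs * hours_per_epoch:.0f} h")  # 40 h, then 64 h
```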

@xiaomei2002

Thank you very much for your answer, which gave me a great help!!! :)

@Felix-FN
Author

FYI: I had an error in the prior trainer. You need to move text_loader = iter(text_dataset) and image_loader = iter(image_dataset) one loop level up, directly after for epoch in range(num_epochs):

Otherwise you would just train your prior on the same first batch of images the whole time...
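The effect of where the iterators are created can be seen with a simplified, pure-Python stand-in for the TextDataset/ImageDataset classes (BatchIterator is hypothetical; like the originals, calling iter() resets the index to 0):

```python
class BatchIterator:
    """Simplified stand-in for TextDataset/ImageDataset: iter() resets the index."""

    def __init__(self, items, batch_size):
        self.items, self.batch_size = items, batch_size

    def __iter__(self):
        self.current_index = 0
        return self

    def __next__(self):
        if self.current_index + self.batch_size > len(self.items):
            raise StopIteration
        batch = self.items[self.current_index:self.current_index + self.batch_size]
        self.current_index += self.batch_size
        return batch


data = BatchIterator(list(range(6)), batch_size=2)
num_epochs, num_batches = 2, 3

# Wrong: iterator re-created for every batch -> always the first batch
wrong = []
for epoch in range(num_epochs):
    for _ in range(num_batches):
        loader = iter(data)
        wrong.append(next(loader))

# Right: iterator re-created once, directly after the epoch loop header
right = []
for epoch in range(num_epochs):
    loader = iter(data)
    for _ in range(num_batches):
        right.append(next(loader))

print(wrong)  # [[0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1]]
print(right)  # [[0, 1], [2, 3], [4, 5], [0, 1], [2, 3], [4, 5]]
```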

@u1ug

u1ug commented Aug 12, 2024

Usually, a Dataset class does not handle batching, shuffling, etc. Instead you use PyTorch's DataLoader class, which takes care of iterating over the dataset, parallel data loading, batching, etc.
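Following that suggestion, a minimal sketch of a map-style Dataset that returns one caption/image pair per index, with DataLoader doing the batching and shuffling. Because each caption travels with its image, shuffling cannot misalign them. CaptionImageDataset, collate, the captions, and the in-memory dummy images are hypothetical; a real version would open and transform the image files from the dataset's paths inside __getitem__:

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset


class CaptionImageDataset(Dataset):
    """Map-style dataset: returns one (text_tokens, image) pair per index."""

    def __init__(self, captions, images, max_length=4500):
        assert len(captions) == len(images)
        self.captions, self.images, self.max_length = captions, images, max_length

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        text = self.captions[idx][:self.max_length]
        # same one-token-per-character encoding as the script in this issue
        tokens = torch.tensor([ord(char) for char in text], dtype=torch.long)
        image = self.images[idx]  # a real dataset would load and transform the file here
        return tokens, image


def collate(batch):
    """Pad the variable-length token sequences with 0 and stack the images."""
    texts, images = zip(*batch)
    return pad_sequence(list(texts), batch_first=True), torch.stack(images)


captions = ["a red car", "two cars at a crossing", "empty road", "a pedestrian"]
images = [torch.rand(3, 64, 64) for _ in captions]  # dummy images instead of files
loader = DataLoader(CaptionImageDataset(captions, images), batch_size=2,
                    shuffle=True, drop_last=True, collate_fn=collate)

for texts, imgs in loader:
    print(texts.shape, imgs.shape)
```

In the training loops, for batch_texts, batch_images in loader: would then replace the hand-written iter()/next() calls, and the DataLoader starts a fresh, reshuffled pass every time it is iterated.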

@LiamLiu62

Thanks for your scripts. Did you use OpenCLIP, or did you train a CLIP model from scratch?

4 participants