## Load package

In [1]:

import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Optional
import math
from tqdm import tqdm

import torch
from datasets import load_dataset
from PIL import Image
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize, ToTensor
from torchvision.transforms.functional import InterpolationMode
import torchvision.transforms as transforms
import torch.nn as nn
import accelerate
import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from PIL import Image

import transformers
from transformers import (
    AutoImageProcessor,
    AutoModel,
    AutoTokenizer,
    HfArgumentParser,
    TrainingArguments,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import send_example_telemetry,ContextManagers
from transformers.utils.versions import require_version
from transformers import CLIPTextModel, CLIPTokenizer
from transformers.trainer_pt_utils import get_parameter_names
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS

import diffusers
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.utils import is_wandb_available
from diffusers.utils.import_utils import is_xformers_available

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.state import AcceleratorState
from accelerate.utils import ProjectConfiguration, set_seed

if is_wandb_available():
    import wandb

## Initialize sh Args

## Args 

In [2]:
args_list = [
    '--output_dir', './clip-roberta-finetuned',
    '--model_name_or_path', '/remote-home/songtianwei/research/diffusion_model_my/clip-roberta',
    '--data_dir', '/remote-home/songtianwei/research/diffusion_model_my/data',
    '--dataset_name', 'ydshieh/coco_dataset_script',
    '--dataset_config_name', '2017',
    '--image_column', 'image_path',
    '--caption_column', 'caption',
    '--remove_unused_columns', 'False',
    '--do_train',
    '--do_eval',
    '--report_to',"wandb" ,
    '--learning_rate',"5e-5",
    '--warmup_steps',"0" ,
    '--weight_decay',"0.1" ,
    '--overwrite_output_dir' ,
    '--max_seq_length',"77" ,
    '--max_steps',"30000" ,
    '--per_device_train_batch_size',"6" ,
    '--per_device_eval_batch_size',"6" ,
    '--max_train_samples',"5000",
    '--max_eval_samples',"1000" ,
]

## Initilize Arguments

In [3]:
logger = get_logger(__name__, log_level="INFO")

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
# check_min_version("4.32.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")


In [4]:

@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    image_processor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
                "with private models)."
            )
        },
    )
    freeze_vision_model: bool = field(
        default=False, metadata={"help": "Whether to freeze the vision model parameters or not."}
    )
    freeze_text_model: bool = field(
        default=False, metadata={"help": "Whether to freeze the text model parameters or not."}
    )


In [5]:

@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    data_dir: Optional[str] = field(default=None, metadata={"help": "The data directory containing input files."})
    image_column: Optional[str] = field(
        default="image_path",
        metadata={"help": "The name of the column in the datasets containing the full image file paths."},
    )
    caption_column: Optional[str] = field(
        default="caption",
        metadata={"help": "The name of the column in the datasets containing the image captions."},
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a jsonlines file)."}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input testing data file (a jsonlines file)."},
    )
    max_seq_length: Optional[int] = field(
        default=128,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    
    # Noise type, default is none, other noise is "random" and "clip_min_noise"
    dataset_noise_type: Optional[str] = field(
        default=None,
        metadata={"help": "The type of noise to add to the dataset."},
    )
    
    dataset_normalize_flag: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to normalize the dataset."},
    )

    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
            if self.validation_file is not None:
                extension = self.validation_file.split(".")[-1]
                assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
            if self.validation_file is not None:
                extension = self.validation_file.split(".")[-1]
                assert extension == "json", "`validation_file` should be a json file."


## dataset name mapping

In [6]:
dataset_name_mapping = {
    "image_caption_dataset.py": ("image_path", "caption"),
}


## Transform

In [7]:

class Transform(torch.nn.Module):
    def __init__(self, image_size, mean=None, std=None):
        super().__init__()
        self.transforms = transforms.Compose([
            Resize([image_size], interpolation=InterpolationMode.BICUBIC,antialias=None),
            CenterCrop(image_size),  # CenterCrop is required because Resize doesn't ensure same output size
            # ConvertImageDtype(torch.float),
            ToTensor(), 
        ])
        if mean is not None and std is not None:
            self.transforms.transforms.append(Normalize(mean=mean, std=std))

    def forward(self, x) -> torch.Tensor:
        """`x` should be an instance of `PIL.Image.Image`"""
        with torch.no_grad():
            x = self.transforms(x)
        return x


In [8]:

def normalize_fn(x, mean, std):
    return Normalize(mean=mean, std=std)(x)


## Collate_fn

In [9]:
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    input_ids = torch.tensor([example["input_ids"] for example in examples], dtype=torch.long)
    attention_mask = torch.tensor([example["attention_mask"] for example in examples], dtype=torch.long)
    return {
        "pixel_values": pixel_values,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "return_loss": True,
    }


## Generate

In [10]:

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.unet_config = {
            "act_fn": "silu",
            "attention_head_dim": 8,
            "block_out_channels": [
                320,
                640,
                1280,
                1280
            ],
            "center_input_sample": False,
            "cross_attention_dim": 768,  # NOTE 768
            "down_block_types": [
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "DownBlock2D"
            ],
            "downsample_padding": 1,
            "flip_sin_to_cos": True,
            "freq_shift": 0,
            "in_channels": 4,
            "layers_per_block": 2,
            "mid_block_scale_factor": 1,
            "norm_eps": 1e-05,
            "norm_num_groups": 32,
            "out_channels": 4,
            "sample_size": 224,
            "up_block_types": [
                "UpBlock2D",
                "CrossAttnUpBlock2D",
                "CrossAttnUpBlock2D",
                "CrossAttnUpBlock2D"
            ]
        }
        self.unet = UNet2DConditionModel(**self.unet_config)
        self.vae_config = {
            'in_channels': 3,
            'out_channels': 3,
            'down_block_types': ['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D'],
            'up_block_types': ['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D'],
            'block_out_channels': [128, 256, 512, 512],
            'layers_per_block': 2,
            'act_fn': 'silu',
            'latent_channels': 4,
            'norm_num_groups': 32,
            'sample_size': 512,
            'scaling_factor': 0.18215,
        }
        self.vae = AutoencoderKL(**self.vae_config)
        
    def forward(self, img_pixel_values, encoder_hidden_states):
        latent = self.vae.encode(img_pixel_values).latent_dist.sample()
        timesteps = torch.randint(0, 1000, (1,),device=latent.device)
        timesteps = timesteps.long()  #  6
        unet_pred = self.unet(latent, timesteps, encoder_hidden_states).sample
        vae_decoding = self.vae.decoder(unet_pred)
        return vae_decoding
    
    def enable_xformers_memory_efficient_attention(self):
        self.unet.enable_xformers_memory_efficient_attention()
        self.vae.enable_xformers_memory_efficient_attention()


## 1. Parse input arguments
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

In [11]:
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
    # If we pass only one argument to the script and it's the path to a json file,
    # let's parse it to get our arguments.
    model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
    print("1")
    model_args, data_args, training_args = parser.parse_args_into_dataclasses(args=args_list)

# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_clip", model_args, data_args)

1


In [12]:
training_args.report_to

['wandb']

## 2. Setup logging

In [13]:
# 2. Setup logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

if training_args.should_log:
    # The default of training_args.log_level is passive, so we set log level at info here to have that default.
    transformers.utils.logging.set_verbosity_info()

log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()

# Log on each process the small summary:
logger.warning(
    f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
    + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")


08/03/2023 21:25:06 - INFO - __main__ - Training/evaluation parameters TrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=True,
do_predict=False,
do_train=True,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=None,
evaluation_strategy=no,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'fsdp_min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=1,
gradient_checkpointing=False,
greater_is_better=None,
group_by_l

## 3.Initialize accelerator and distributed training

In [14]:
accelerator_project_config = ProjectConfiguration(total_limit=training_args.save_total_limit)

accelerator = Accelerator(
    gradient_accumulation_steps=training_args.gradient_accumulation_steps,
    mixed_precision="no",
    log_with=training_args.report_to,
    project_config=accelerator_project_config,
)

logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
    datasets.utils.logging.set_verbosity_warning()
    transformers.utils.logging.set_verbosity_warning()
    diffusers.utils.logging.set_verbosity_info()
else:
    datasets.utils.logging.set_verbosity_error()
    transformers.utils.logging.set_verbosity_error()
    diffusers.utils.logging.set_verbosity_error()
    


08/03/2023 21:25:06 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: no




## 4. Detecting last checkpoint and eventualy continue from last checkpoint

In [15]:


last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )
    elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(
            f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
            "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
        )
       

## 5. Load dataset

In [16]:

if data_args.dataset_name is not None:
    # Downloading and loading a dataset from the hub.
    dataset = load_dataset(
        data_args.dataset_name,
        data_args.dataset_config_name,
        cache_dir=model_args.cache_dir,
        keep_in_memory=False,
        data_dir=data_args.data_dir,
        use_auth_token=True if model_args.use_auth_token else None,
    )
else:
    data_files = {}
    if data_args.train_file is not None:
        data_files["train"] = data_args.train_file
        extension = data_args.train_file.split(".")[-1]
    if data_args.validation_file is not None:
        data_files["validation"] = data_args.validation_file
        extension = data_args.validation_file.split(".")[-1]
    if data_args.test_file is not None:
        data_files["test"] = data_args.test_file
        extension = data_args.test_file.split(".")[-1]
    dataset = load_dataset(
        extension,
        data_files=data_files,
        cache_dir=model_args.cache_dir,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    



  0%|          | 0/3 [00:00<?, ?it/s]

## 6. Load pretrained model, tokenizer, and image processor

In [17]:

if model_args.tokenizer_name:
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
    )
elif model_args.model_name_or_path:
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
    )
else:
    raise ValueError(
        "You are instantiating a new tokenizer from scratch. This is not supported by this script."
        "You can do it from another script, save it, and load it from here, using --tokenizer_name."
    )

In [18]:
revision = None

In [19]:
model_args.model_name_or_path

'/remote-home/songtianwei/research/diffusion_model_my/clip-roberta'

In [20]:
pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4"

In [21]:
# Note: Do not use the clip tokenizer, loss can not decrease
use_clip_tokenizer = False
if use_clip_tokenizer:
    tokenizer = CLIPTokenizer.from_pretrained(
        pretrained_model_name_or_path, subfolder="tokenizer", revision=revision
    )

In [22]:

# Load image_processor, in this script we only use this to get the mean and std for normalization.
image_processor = AutoImageProcessor.from_pretrained(
    model_args.image_processor_name or model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
    revision=model_args.model_revision,
    use_auth_token=True if model_args.use_auth_token else None,
)


### clip_model

In [23]:
clip_model = AutoModel.from_pretrained(
    model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
    revision=model_args.model_revision,
    use_auth_token=True if model_args.use_auth_token else None,
)
clip_model_config = clip_model.config
clip_pretrained = False
if clip_pretrained:
    pass
else:
    clip_model = AutoModel.from_config(clip_model_config)

In [24]:

def _freeze_params(module):
    for param in module.parameters():
        param.requires_grad = False

if model_args.freeze_vision_model:
    _freeze_params(clip_model.vision_model)

if model_args.freeze_text_model:
    _freeze_params(clip_model.text_model)


In [25]:

if training_args.seed is not None:
    set_seed(training_args.seed)
    

### Generator

In [26]:
generator = Generator()

In [27]:
# logger.info(f"generator_train: {generator_train}")

### text_encoder

In [28]:
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
)
# text_encoder
text_encoder.requires_grad_(False)
weight_dtype = torch.float32

## 7. Get the column names for input/target.

In [29]:
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
if training_args.do_train:
    column_names = dataset["train"].column_names
elif training_args.do_eval:
    column_names = dataset["validation"].column_names
elif training_args.do_predict:
    column_names = dataset["test"].column_names
else:
    logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")

In [30]:


dataset_columns = dataset_name_mapping.get(data_args.dataset_name, None)
if data_args.image_column is None:
    image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
else:
    image_column = data_args.image_column
    if image_column not in column_names:
        raise ValueError(
            f"--image_column' value '{data_args.image_column}' needs to be one of: {', '.join(column_names)}"
        )
if data_args.caption_column is None:
    caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
else:
    caption_column = data_args.caption_column
    if caption_column not in column_names:
        raise ValueError(
            f"--caption_column' value '{data_args.caption_column}' needs to be one of: {', '.join(column_names)}"
        )


## 8. Preprocessing the datasets.

In [31]:
image_column

'image_path'

In [32]:
image_transformations = Transform(
    clip_model_config.vision_config.image_size
)

### tokenize_captions

In [33]:
def tokenize_captions(examples):
    captions = list(examples[caption_column])
    text_inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", truncation=True)
    examples["input_ids"] = text_inputs.input_ids
    examples["attention_mask"] = text_inputs.attention_mask
    return examples

### transform_images

In [34]:

def transform_images(examples):
    if isinstance(examples[image_column][0],str):
        # For coco dataset, the images are loaded as path
        # images = [read_image(image_file, mode=ImageReadMode.RGB) for image_file in examples[image_column]]
        images = [Image.open(image_file).convert("RGB") for image_file in examples[image_column]]
    else:
        # lambdalabs/pokemon-blip-captions
        images = [image.convert("RGB") for image in examples[image_column]]
    examples["pixel_values"] = [image_transformations(image) for image in images]
    return examples

### filter_corrupt_images

In [35]:

def filter_corrupt_images(examples):
    """remove problematic images"""
    valid_images = []
    for image_file in examples[image_column]:
        try:
            Image.open(image_file).convert("RGB") 
            valid_images.append(True)
        except Exception:
            valid_images.append(False)
    return valid_images

### do_train

In [36]:

if training_args.do_train:
    with accelerator.main_process_first():
        if "train" not in dataset:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = dataset["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        # print(len(train_dataset))
        train_dataset = train_dataset.filter(
            filter_corrupt_images, batched=True, num_proc=data_args.preprocessing_num_workers
        )
        
        train_dataset = train_dataset.map(
            function=tokenize_captions,
            batched=True,
            remove_columns=[col for col in column_names if col != image_column],
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on train dataset",
        )
    
        # Transform images on the fly as doing it on the whole dataset takes too much time.
        train_dataset.set_transform(transform_images)



In [37]:
# for data in train_dataset:
#     print(data)
#     break

In [38]:
train_dataset

Dataset({
    features: ['image_path', 'input_ids', 'attention_mask'],
    num_rows: 4097
})

In [39]:
train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=False,  # here change to False to check the order of the images
        collate_fn=collate_fn,
        batch_size=training_args.train_batch_size,
        num_workers=training_args.dataloader_num_workers,
        drop_last=True,
    )

In [40]:
# for idx,batch in enumerate(train_dataloader):
#     if idx>4:
#         print(batch)
#     if(idx>7):
#         break

### do_eval

In [41]:

if training_args.do_eval:
    with accelerator.main_process_first():
        if "validation" not in dataset:
            raise ValueError("--do_eval requires a train validation")
        eval_dataset = dataset["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))
    
        eval_dataset = eval_dataset.filter(
            filter_corrupt_images, batched=True, num_proc=data_args.preprocessing_num_workers
        )
        eval_dataset = eval_dataset.map(
            function=tokenize_captions,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=[col for col in column_names if col != image_column],
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on validation dataset",
        )
    
        # Transform images on the fly as doing it on the whole dataset takes too much time.
        eval_dataset.set_transform(transform_images)




In [42]:
# evaluation dataloader
eval_dataloader = torch.utils.data.DataLoader(
    eval_dataset,
    shuffle=True,
    collate_fn=collate_fn,
    batch_size=training_args.eval_batch_size,
    num_workers=training_args.dataloader_num_workers,
    drop_last=True,
)


### do_predict

In [43]:

if training_args.do_predict:
    with accelerator.main_process_first():
        if "test" not in dataset:
            raise ValueError("--do_predict requires a test dataset")
        test_dataset = dataset["test"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(test_dataset), data_args.max_eval_samples)
            test_dataset = test_dataset.select(range(max_eval_samples))
    
        test_dataset = test_dataset.filter(
            filter_corrupt_images, batched=True, num_proc=data_args.preprocessing_num_workers
        )
        test_dataset = test_dataset.map(
            function=tokenize_captions,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=[col for col in column_names if col != image_column],
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on test dataset",
        )
    
        # Transform images on the fly as doing it on the whole dataset takes too much time.
        test_dataset.set_transform(transform_images)
        

### Normailze fn

In [44]:
def normalize_fn(x, mean, std):
    return transforms.Normalize(mean=mean, std=std)(x)

## 9.Initialize the optimizer

In [45]:
use_8bit_adam = True
if use_8bit_adam:
    try:
        import bitsandbytes as bnb
    except ImportError:
        raise ImportError(
            "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
        )

    optimizer_cls = bnb.optim.AdamW8bit
else:
    optimizer_cls = torch.optim.AdamW
    

Either way, this might cause trouble in the future:
If you get `CUDA error: invalid device function` errors, the above might be the cause and the solution is to make sure only one ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] in the paths that we search based on your env.
  warn(msg)



Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
bin /remote-home/songtianwei/conda/envs/pytorch2/lib/python3.8/site-packages/bitsandbytes/libbitsandbytes_cuda113.so
CUDA SETUP: CUDA runtime path found: /remote-home/songtianwei/conda/envs/pytorch2/lib/libcudart.so.11.0
CUDA SETUP: Highest compute capability among GPUs detected: 8.6
CUDA SETUP: Detected CUDA version 113
CUDA SETUP: Loading binary /remote-home/songtianwei/conda/envs/pytorch2/lib/python3.8/site-packages/bitsandbytes/libbitsandbytes_cuda113.so...


In [46]:
optimizer_cls

bitsandbytes.optim.adamw.AdamW8bit

In [47]:
decay_parameters = get_parameter_names(clip_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in clip_model.named_parameters() if (n in decay_parameters and p.requires_grad)
            ],
            "weight_decay": 0.1,
        },
        {
            "params": [
                p for n, p in clip_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
            ],
            "weight_decay": 0.0,
        },
    ]
adam_kwargs = {
    "lr": 5e-5,
    "betas": (0.9, 0.999),
    "eps": 1e-8,
}

optimizer = optimizer_cls(optimizer_grouped_parameters, **adam_kwargs)
optimizer_generator = optimizer_cls(generator.parameters(), **adam_kwargs)

In [48]:
decay_parameters

['vision_model.vision_model.embeddings.patch_embedding.weight',
 'vision_model.vision_model.embeddings.position_embedding.weight',
 'vision_model.vision_model.embeddings.class_embedding',
 'vision_model.vision_model.encoder.layers.0.self_attn.k_proj.weight',
 'vision_model.vision_model.encoder.layers.0.self_attn.v_proj.weight',
 'vision_model.vision_model.encoder.layers.0.self_attn.q_proj.weight',
 'vision_model.vision_model.encoder.layers.0.self_attn.out_proj.weight',
 'vision_model.vision_model.encoder.layers.0.mlp.fc1.weight',
 'vision_model.vision_model.encoder.layers.0.mlp.fc2.weight',
 'vision_model.vision_model.encoder.layers.1.self_attn.k_proj.weight',
 'vision_model.vision_model.encoder.layers.1.self_attn.v_proj.weight',
 'vision_model.vision_model.encoder.layers.1.self_attn.q_proj.weight',
 'vision_model.vision_model.encoder.layers.1.self_attn.out_proj.weight',
 'vision_model.vision_model.encoder.layers.1.mlp.fc1.weight',
 'vision_model.vision_model.encoder.layers.1.mlp.fc2.w

In [49]:
lr_scheduler = 'linear'
lr_warmup_steps = 0

In [50]:
# lr_scheduler = get_scheduler(
#         lr_scheduler,
#         optimizer=optimizer,
#         num_warmup_steps=lr_warmup_steps * training_args.gradient_accumulation_steps,
#         num_training_steps=training_args.max_steps * training_args.gradient_accumulation_steps,
#     )

In [51]:
# ToDo Optimizer for generator

## 10.Initial About the accelerate

In [52]:
# Handle the repository creation
if accelerator.is_main_process:
    if training_args.output_dir is not None:
        os.makedirs(training_args.output_dir, exist_ok=True)


In [53]:

# For optimizer and scheduler
optimizer = accelerator.prepare(optimizer)
# lr_scheduler = accelerator.prepare(lr_scheduler)


### load model and optimizer and dataloader to accelerate

In [54]:
# For model
add_noise = True
clip_model = accelerator.prepare(clip_model)
if add_noise:
    generator = accelerator.prepare(generator)
train_dataloader = accelerator.prepare(train_dataloader)
eval_dataloader = accelerator.prepare(eval_dataloader)
text_encoder.to(accelerator.device, dtype=weight_dtype)

CLIPTextModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e

## 11.Initialize max_train_step and tracker ( wandb start)

In [55]:
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
if training_args.max_steps is None or training_args.max_steps <= 0:
    training_args.max_steps = training_args.num_train_epochs * num_update_steps_per_epoch
    overrode_max_train_steps = True
training_args.max_steps = (int)(training_args.max_steps)
training_args.num_train_epochs = (int)(training_args.num_train_epochs)

In [56]:
training_args.num_train_epochs

3

In [57]:
num_update_steps_per_epoch

682

In [58]:
tracker_project_name = "text2image-fine-tune"

In [59]:
# # We need to initialize the trackers we use, and also store our configuration.
# # The trackers initializes automatically on the main process.
# if accelerator.is_main_process:
#     tracker_config = dict(vars(args))
#     # tracker_config.pop("validation_prompts")
#     accelerator.init_trackers(tracker_project_name, tracker_config)

## 12. Start Train

### log training info

In [60]:
total_batch_size = training_args.train_batch_size * accelerator.num_processes * training_args.gradient_accumulation_steps

logger.info("***** Running training *****")
if training_args.do_train:
    logger.info(f"  Training num examples = {len(train_dataset)}")
if training_args.do_eval:
    logger.info(f"  Evaluation num examples = {len(eval_dataset)}")
logger.info(f"  Num Epochs = {training_args.num_train_epochs}")
logger.info(f"  Instantaneous batch size per device = {training_args.train_batch_size}")
logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f"  Gradient Accumulation steps = {training_args.gradient_accumulation_steps}")
logger.info(f"  Total optimization steps = {training_args.max_steps}")
global_step = 0
first_epoch = 0

08/03/2023 21:25:44 - INFO - __main__ - ***** Running training *****
08/03/2023 21:25:44 - INFO - __main__ -   Training num examples = 4097
08/03/2023 21:25:44 - INFO - __main__ -   Evaluation num examples = 1000
08/03/2023 21:25:44 - INFO - __main__ -   Num Epochs = 3
08/03/2023 21:25:44 - INFO - __main__ -   Instantaneous batch size per device = 6
08/03/2023 21:25:44 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 6
08/03/2023 21:25:44 - INFO - __main__ -   Gradient Accumulation steps = 1
08/03/2023 21:25:44 - INFO - __main__ -   Total optimization steps = 30000


### resume from checkpoint

In [61]:
if training_args.resume_from_checkpoint:
    if training_args.resume_from_checkpoint != "latest":
        path = os.path.basename(training_args.resume_from_checkpoint)
    else:
        # Get the most recent checkpoint
        dirs = os.listdir(training_args.output_dir)
        dirs = [d for d in dirs if d.startswith("checkpoint")]
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        path = dirs[-1] if len(dirs) > 0 else None

    if path is None:
        accelerator.print(
            f"Checkpoint '{training_args.resume_from_checkpoint}' does not exist. Starting a new training run."
        )
        training_args.resume_from_checkpoint = None
    else:
        accelerator.print(f"Resuming from checkpoint {path}")
        accelerator.load_state(os.path.join(training_args.output_dir, path))
        global_step = int(path.split("-")[1])

        resume_global_step = global_step * training_args.gradient_accumulation_steps
        first_epoch = global_step // num_update_steps_per_epoch
        resume_step = resume_global_step % (num_update_steps_per_epoch * training_args.gradient_accumulation_steps)
 

In [62]:
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, training_args.max_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Training Steps")

Training Steps:   0%|                                                                                                             | 0/30000 [00:00<?, ?it/s]

In [63]:
accelerator.free_memory()

### Training Circle

In [73]:

add_noise = False

generator_train = False
clip_train = True

if add_noise:
    if generator_train:
        generator.train()
        generator.requires_grad_(True)
        generator.zero_grad()
    else:
        generator.eval()
        generator.requires_grad_(False)
else:
    pass
    
if clip_train:
    clip_model.train()
    clip_model.requires_grad_(True)
else:
    clip_model.eval()
    clip_model.requires_grad_(False)
clip_model.zero_grad()
    
use_normailize = True

logger.info("clip_train: {}, generator_train: {}".format(clip_train, generator_train))
logger.info(f"add_noise: {add_noise}, use_normailize: {use_normailize}")

08/03/2023 21:30:36 - INFO - __main__ - clip_train: True, generator_train: False
08/03/2023 21:30:36 - INFO - __main__ - add_noise: False, use_normailize: True


In [74]:
# for step, batch in enumerate(train_dataloader):
#     text_encoder = clip_model.text_model
#     batch_pixel_values = batch["pixel_values"]  # [6,3,224,224]
#     batch_input_ids = batch["input_ids"]
#     batch_attention_mask = batch["attention_mask"]
#     print(batch_input_ids.shape)  # [64,128]
#     with torch.no_grad():
#         encoder_hidden_states = text_encoder(batch_input_ids,batch_attention_mask)[0]  # [6,128,768]     
#     print(encoder_hidden_states.shape)
#     noise = generator(batch_pixel_values, encoder_hidden_states)
#     break

In [75]:
add_noise = True

In [77]:

for epoch in range(first_epoch, training_args.num_train_epochs):
    if training_args.do_train:
        logging.info("*"*50)
        logging.info("Doing Training")
        logging.info("*"*50)
            
        progress_bar.set_description("Training Steps")
        train_loss = 0.0

        generator_step_M = 1
        clip_step_N = 1
        train_target_list = ["generator"]*generator_step_M + ["clip"]*clip_step_N
        cur_index = 0
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if training_args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % training_args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue
            # which to train
            train_target = train_target_list[cur_index]
            cur_index = (cur_index + 1) % len(train_target_list)
            if train_target == "generator":
                pass

            batch_pixel_values = batch["pixel_values"]  # [6,3,224,224]
            
            batch_input_ids = batch["input_ids"]
            batch_attention_mask = batch["attention_mask"]
            print(batch_pixel_values.shape)
            
            if add_noise:
                text_encoder = clip_model.text_model
                with torch.no_grad():
                    encoder_hidden_states = text_encoder(batch_input_ids,batch_attention_mask)[0]  # [6,128,768]                
                noise = generator(batch_pixel_values, encoder_hidden_states)
                
                # limit the norm of the noise
                norm_type = 'l2'
                epsilon = 16
                if norm_type == 'l2':
                    temp = torch.norm(noise.view(noise.shape[0], -1), dim=1).view(-1, 1, 1, 1)
                    noise = noise * epsilon / temp
                else:
                    noise = torch.clamp(noise, -epsilon / 255, epsilon / 255)
                image = batch_pixel_values + noise 
                image = torch.clamp(image, -1, 1)
            else:
                image = batch_pixel_values
                 
            if use_normailize:
                image = normalize_fn(image, mean=image_processor.image_mean, std=image_processor.image_std)
                
            batch_data_input = {
                "input_ids":batch_input_ids,
                "pixel_values" : image,
                "attention_mask":batch["attention_mask"],
                "return_loss": True
            }
            output = clip_model(**batch_data_input)
            logits_per_image = output.logits_per_image   # for training , image_logits is the same as logits text
            logits_per_text = output.logits_per_text
            
            loss = output.loss

            # Gather the losses across all processes for logging (if we use distributed training).
            avg_loss = accelerator.gather(loss.repeat(training_args.train_batch_size)).mean()
            train_loss += avg_loss.item() / training_args.gradient_accumulation_steps
            print(avg_loss)
            # Backpropagate
            accelerator.backward(loss)

            if accelerator.sync_gradients:
                if generator_train:
                    accelerator.clip_grad_norm_(generator.parameters(), training_args.max_grad_norm)
                elif clip_train:
                    accelerator.clip_grad_norm_(clip_model.parameters(), training_args.max_grad_norm)
            
            # Update optimizer
            optimizer.step()
            optimizer_generator.step()
            # lr_scheduler.step()
            
            clip_model.zero_grad()
            generator.zero_grad()
            # optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                # accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                checkpointing_steps = 100
                if global_step % checkpointing_steps == 0:
                    logging.info("Epoch : {} ; Step : {} ; Save checkpoint to {}".format(epoch, global_step, training_args.output_dir))
                    if accelerator.is_main_process:
                        save_path = os.path.join(training_args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")
            
            record = {
                    "epoch": epoch,
                    "step": step,
                    "global_step":global_step,
                    "train_loss": loss.detach().item(),
                    "lr": optimizer.param_groups[0]["lr"],
                    }
            # wandb.log(record)  
            progress_bar.set_postfix(**record)

            if global_step >= training_args.max_steps:
                break

08/03/2023 21:31:06 - INFO - root - **************************************************
08/03/2023 21:31:06 - INFO - root - Doing Training
08/03/2023 21:31:06 - INFO - root - **************************************************
Training Steps:   1%|▎                                   | 245/30000 [05:21<4:21:38,  1.90it/s, epoch=0, global_step=245, lr=5e-5, step=24, train_loss=1.79]Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])
tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Training Steps:   1%|▎                                   | 246/30000 [05:22<39:22:00,  4.76s/it, epoch=0, global_step=246, lr=5e-5, step=0, train_loss=1.79]Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 247/30000 [05:22<28:54:35,  3.50s/it, epoch=0, global_step=247, lr=5e-5, step=1, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 248/30000 [05:23<21:36:08,  2.61s/it, epoch=0, global_step=248, lr=5e-5, step=2, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 249/30000 [05:23<16:20:51,  1.98s/it, epoch=0, global_step=249, lr=5e-5, step=3, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 250/30000 [05:24<12:59:23,  1.57s/it, epoch=0, global_step=250, lr=5e-5, step=4, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 251/30000 [05:25<10:33:03,  1.28s/it, epoch=0, global_step=251, lr=5e-5, step=5, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                    | 252/30000 [05:25<8:51:12,  1.07s/it, epoch=0, global_step=252, lr=5e-5, step=6, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                    | 253/30000 [05:26<7:38:50,  1.08it/s, epoch=0, global_step=253, lr=5e-5, step=7, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                    | 254/30000 [05:26<6:37:20,  1.25it/s, epoch=0, global_step=254, lr=5e-5, step=8, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                    | 255/30000 [05:27<6:13:44,  1.33it/s, epoch=0, global_step=255, lr=5e-5, step=9, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 256/30000 [05:28<5:51:01,  1.41it/s, epoch=0, global_step=256, lr=5e-5, step=10, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 257/30000 [05:28<5:33:32,  1.49it/s, epoch=0, global_step=257, lr=5e-5, step=11, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 258/30000 [05:29<5:20:26,  1.55it/s, epoch=0, global_step=258, lr=5e-5, step=12, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 259/30000 [05:29<5:20:30,  1.55it/s, epoch=0, global_step=259, lr=5e-5, step=13, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)
torch.Size([6, 3, 224, 224])


Forward upsample size to force interpolation output size.
Training Steps:   1%|▎                                   | 260/30000 [05:30<4:56:06,  1.67it/s, epoch=0, global_step=260, lr=5e-5, step=14, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 261/30000 [05:30<4:52:55,  1.69it/s, epoch=0, global_step=261, lr=5e-5, step=15, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 262/30000 [05:31<4:55:41,  1.68it/s, epoch=0, global_step=262, lr=5e-5, step=16, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 263/30000 [05:32<4:52:57,  1.69it/s, epoch=0, global_step=263, lr=5e-5, step=17, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 264/30000 [05:32<4:51:45,  1.70it/s, epoch=0, global_step=264, lr=5e-5, step=18, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 265/30000 [05:33<4:49:37,  1.71it/s, epoch=0, global_step=265, lr=5e-5, step=19, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 266/30000 [05:33<4:50:54,  1.70it/s, epoch=0, global_step=266, lr=5e-5, step=20, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 267/30000 [05:34<4:48:48,  1.72it/s, epoch=0, global_step=267, lr=5e-5, step=21, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 268/30000 [05:35<4:58:14,  1.66it/s, epoch=0, global_step=268, lr=5e-5, step=22, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 269/30000 [05:35<4:57:07,  1.67it/s, epoch=0, global_step=269, lr=5e-5, step=23, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 270/30000 [05:36<4:56:49,  1.67it/s, epoch=0, global_step=270, lr=5e-5, step=24, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 271/30000 [05:36<4:49:33,  1.71it/s, epoch=0, global_step=271, lr=5e-5, step=25, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 272/30000 [05:37<4:49:51,  1.71it/s, epoch=0, global_step=272, lr=5e-5, step=26, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 273/30000 [05:38<4:51:18,  1.70it/s, epoch=0, global_step=273, lr=5e-5, step=27, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 274/30000 [05:38<4:56:34,  1.67it/s, epoch=0, global_step=274, lr=5e-5, step=28, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 275/30000 [05:39<4:57:55,  1.66it/s, epoch=0, global_step=275, lr=5e-5, step=29, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 276/30000 [05:39<4:58:11,  1.66it/s, epoch=0, global_step=276, lr=5e-5, step=30, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


Forward upsample size to force interpolation output size.


torch.Size([6, 3, 224, 224])


Training Steps:   1%|▎                                   | 277/30000 [05:40<4:47:53,  1.72it/s, epoch=0, global_step=277, lr=5e-5, step=31, train_loss=1.79]

tensor(1.7918, device='cuda:0', grad_fn=<MeanBackward0>)


KeyboardInterrupt: 

In [None]:
# accelerator.end_training()