adding more "hint" to training process #271

Open · orydatadudes opened this issue Mar 14, 2023 · 14 comments

@orydatadudes

Hi,
I am working on the human posture task (getting a pose from an OpenPose image + prompt and then generating the character in the right pose, using control_sd15_openpose.pth).

However, I wanted to add one more hint to force the ControlNet to generate a specific person. So where the original code uses a pose image as the hint, like this:

[pose image: v2-c5e272899550ac318ed4732336fd7c82_720w]

I would like to also add an image of the specific person:

[reference image: MEN_Denim_id_00000080_0_01_7_additional]

The target should be an image of that person in the new pose.

So what I did is:

  1. In the dataset file: read that extra image too and concatenate it with the pose image along the channel dimension, so the source variable now has 6 channels instead of 3 (see the sketch after the config below):

     # concatenate the pose hint and the reference image
     source = np.concatenate([source, source_image], axis=2)

     return dict(jpg=target, txt=prompt, hint=source)

  2. Changing the YAML config file to support 6 channels (I am not sure I really understood the meaning of these values):

model:
  target: cldm.cldm.ControlLDM
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    control_key: "hint"
    image_size: 64
    channels: 7 # was 4, I changed it to 7
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    only_mid_control: False

    control_stage_config:
      target: cldm.cldm.ControlNet
      params:
        image_size: 32 # unused
        in_channels: 7 # was 4, I changed it to 7
        hint_channels: 6 # was 3, I changed it to 6
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    unet_config:
      target: cldm.cldm.ControlledUnetModel
      params:
        image_size: 32 # unused
        in_channels: 7 # was 4, I changed it to 7
        out_channels: 7 # was 4, I changed it to 7
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 7 # was 4, I changed it to 7
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 7 # was 4, I changed it to 7
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
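
For reference, here is a minimal sketch of the dataset change described in step 1 (the prompt.json keys, the 'source_image' field, and the paths are assumptions, not the original code):

    import json
    import os

    import cv2
    import numpy as np
    from torch.utils.data import Dataset

    class TwoHintDataset(Dataset):
        def __init__(self, root='./training/my_dataset'):
            self.root = root
            with open(os.path.join(root, 'prompt.json'), 'rt') as f:
                self.data = [json.loads(line) for line in f]

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            item = self.data[idx]

            def read_rgb(name):
                # OpenCV reads BGR; convert to RGB
                img = cv2.imread(os.path.join(self.root, name))
                return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            source = read_rgb(item['source'])              # pose hint
            source_image = read_rgb(item['source_image'])  # reference person (assumed key)
            target = read_rgb(item['target'])

            # stack both hints along the channel axis -> (H, W, 6)
            source = np.concatenate([source, source_image], axis=2)
            source = source.astype(np.float32) / 255.0          # hint in [0, 1]
            target = (target.astype(np.float32) / 127.5) - 1.0  # target in [-1, 1]
            return dict(jpg=target, txt=item['prompt'], hint=source)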

The problem is that when I train the model from scratch (running tutorial_train.py with resume_path = None), the model predictions, the reconstructions, and the samples under image_log/train are all just noise.

Does anyone have any idea how to solve this?
Thanks

@liuzhihui2046

I don't think you need to start training from scratch. If you fine-tune from a trained model, it should converge faster.

@fkcptlst

fkcptlst commented Apr 9, 2023

To my knowledge, you should probably alter hint_channels under control_stage_config instead of channels.
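
For a 6-channel hint, a minimal sketch of the change against cldm_v15.yaml would then be (everything else, including the 4 latent channels, left untouched):

    control_stage_config:
      target: cldm.cldm.ControlNet
      params:
        hint_channels: 6 # was 3; two stacked RGB hints
        # all other params unchanged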

@SnowdenLee

Hi, do you know why use_ema is set to False?

@Yuhyeong

Hello!

I'm making a similar attempt. Do you have any further results?

@MuyuenLP

MuyuenLP commented Jan 3, 2024

> (quoting @Yuhyeong's comment above)

Hello!

Have you solved the problem? I wonder if I could learn from your work.

Thank you

@Yuhyeong

Yuhyeong commented Jan 3, 2024

> (quoting @MuyuenLP's question above)

I finished my work months ago; it works, but is not significantly effective.

In the config, I only changed hint_channels to 6.

Then I merged two 3-channel images into one 6-channel image, saved it as a TIFF, and created a customized dataset object for training. My dataset code is below.

import json
import os

import cv2
import numpy as np
import tifffile
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self):
        self.data = []
        with open('./training/pose+face/prompt.json', 'rt') as f:
            for line in f:
                self.data.append(json.loads(line))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        source_filename = item['source']
        target_filename = item['target']
        prompt = item['prompt']

        # the source is a 6-channel TIFF (pose + face stacked channel-wise)
        source = tifffile.imread(os.path.join('./training/pose+face/source', source_filename))
        target = cv2.imread(os.path.join('./training/pose+face/target', target_filename))

        # Do not forget that OpenCV reads images in BGR order.
        target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)

        # The TIFF is stored channels-first; move channels last, then normalize to [0, 1].
        source = np.transpose(source, (1, 2, 0))
        source = source.astype(np.float32) / 255.0

        # Normalize target images to [-1, 1].
        target = (target.astype(np.float32) / 127.5) - 1.0

        return dict(jpg=target, txt=prompt, hint=source)
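
For reference, a minimal sketch of how such a 6-channel TIFF could be written with tifffile (the file names are hypothetical; the channels-first layout matches the np.transpose above):

    import cv2
    import numpy as np
    import tifffile

    # read the two 3-channel hints (file names are hypothetical)
    pose = cv2.cvtColor(cv2.imread('pose.png'), cv2.COLOR_BGR2RGB)
    face = cv2.cvtColor(cv2.imread('face.png'), cv2.COLOR_BGR2RGB)

    # stack to (H, W, 6), then store channels-first so that
    # np.transpose(source, (1, 2, 0)) in the dataset restores (H, W, 6)
    merged = np.concatenate([pose, face], axis=2).transpose(2, 0, 1)
    tifffile.imwrite('./training/pose+face/source/0001.tiff', merged)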

@MuyuenLP

MuyuenLP commented Jan 4, 2024

> (quoting @Yuhyeong's reply above)

Thanks for your prompt reply; I am trying to do as you say. I failed to save the merged image as a TIFF, so I used numpy.concatenate to merge the two images, like this:

    stacked_array = np.concatenate((inpaint_resize, ref_image), axis=2)
    # inpaint_resize: (512, 512, 3)
    # ref_image: (512, 512, 3)

Then I get a (512, 512, 6) numpy array as the hint, but something goes wrong:

      File "/root/autodl-tmp/ControlNet-v1-1-nightly/cldm/logger.py", line 40, in log_local
        Image.fromarray(grid).save(path)
      File "/root/miniconda3/lib/python3.8/site-packages/PIL/Image.py", line 3102, in fromarray
        raise TypeError(msg) from e
    TypeError: Cannot handle this data type: (1, 1, 6), |u1

I am trying to fix this problem. May I ask whether you did any operation other than modifying hint_channels, or could you share the part where you save the TIFF? I would be very, very grateful.

@Yuhyeong

Yuhyeong commented Jan 5, 2024

> (quoting @MuyuenLP's reply above)

That is a harmless error; you can just ignore it, since that code only records the image log during training.

    if grid.shape[2] == 6:
        grid = grid[:, :, :3]
        continue

Add this code in cldm/logger.py, in log_local before the Image.fromarray(grid).save(path) call that raises at PIL/Image.py line 3102, to skip 6-channel grids.
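
A sketch of where the guard would sit (the surrounding loop and names are assumptions reconstructed from the traceback, not the actual logger code):

    # inside log_local in cldm/logger.py (loop structure assumed)
    for k in images:
        grid = grids[k]                            # hypothetical: uint8 HWC grid prepared by the logger
        if grid.ndim == 3 and grid.shape[2] == 6:  # 6-channel hint grid
            continue                               # PIL cannot save it, so skip logging it
        Image.fromarray(grid).save(path)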

@MuyuenLP

MuyuenLP commented Jan 6, 2024

> (quoting @Yuhyeong's fix above)

It makes sense! Thanks for your useful advice!

@SamanFekri

Hey guys,
I also want to use multiple controls for my thesis on ControlNet. Could you add the complete YAML config and also the MyDataset class for this problem?
If I get some results I will add mine.

@SamanFekri

SamanFekri commented Feb 16, 2024

I changed my model config to this:

model:
  target: cldm.scldm.ExtendedControlLDM
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    control_key: "hint"
    image_size: 64
    channels: 4 # kept at 4 (latent channels; only hint_channels changes below)
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    only_mid_control: False

    control_stage_config:
      target: cldm.cldm.ControlNet
      params:
        image_size: 32 # unused
        in_channels: 4
        hint_channels: 9 # 3 for one image, 2 * 3 for 2 image hint, 3 * 3 for 3 images
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    unet_config:
      target: cldm.cldm.ControlledUnetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

I added the hints inside my Dataset class: I changed the dataset class to load the extra images and concatenate them to each other as the source:

        source = cv2.imread(f'{self.dataset_path}/{source_filename}')
        target = cv2.imread(f'{self.dataset_path}/{target_filename}')

        # Do not forget that OpenCV reads images in BGR order.
        source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
        target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)

        # Add a Canny edge map as the second control
        detected_map = resize_image(HWC3(target), self.resolution)
        detected_map = self.apply_canny(detected_map, self.canny_low, self.canny_high)
        canny = HWC3(detected_map)

        # Downscale then upscale the original image as the third control
        resize = cv2.resize(target, self.small_dim, interpolation=cv2.INTER_AREA)
        resize = cv2.resize(resize, self.original_dim, interpolation=cv2.INTER_AREA)

        # Concatenate the channels into the source: (H, W, 9)
        source = np.concatenate((source, resize, canny), axis=2)

        # Normalize source images to [0, 1].
        # source = np.transpose(source, (1, 2, 0))
        source = source.astype(np.float32) / 255.0
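
Presumably the rest of __getitem__ follows the standard tutorial dataset (an assumption; the snippet above ends at the source normalization):

        # Normalize target images to [-1, 1].
        target = (target.astype(np.float32) / 127.5) - 1.0

        return dict(jpg=target, txt=prompt, hint=source)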

I started from a pretrained Stable Diffusion model. I need to load the pretrained weights into the modified model, so I duplicate the hint-block weights, as you can see in the following code:

# Model creation
model = create_model(config['model']['config_file']).cpu()

lsd = load_state_dict(resume_path, location='cpu')

# Repeat the 3-channel hint-block weights num_hints times along the input-channel
# dimension: (out_ch, 3, kH, kW) -> (out_ch, 3 * num_hints, kH, kW)
repeated_tensor = torch.stack([torch.tensor(item).repeat(1, config['model']['num_hints'], 1, 1)
                               for item in lsd['control_model.input_hint_block.0.weight']]).squeeze(1)

# Assign the corrected tensor to the state dictionary
lsd['control_model.input_hint_block.0.weight'] = repeated_tensor

In the above code, config['model']['num_hints'] = 3.
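
A simpler equivalent, assuming the checkpoint weight has shape (out_ch, 3, kH, kW), is to repeat it directly along the input-channel dimension:

    w = lsd['control_model.input_hint_block.0.weight']  # (out_ch, 3, kH, kW)
    lsd['control_model.input_hint_block.0.weight'] = w.repeat(1, config['model']['num_hints'], 1, 1)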

@bhosalems

> (quoting @SamanFekri's comment above)

Thank you for the inputs. May I know if you were able to get good results with this?

@jiachen0212

> (quoting @SamanFekri's comment and @bhosalems's question above)

Maybe we can't achieve the desired result. I tried a segmentation map plus depth... If there are no other bugs in my experiment, then the conclusion is: the images are not OK.
Were you able to train with two controls and get good results?

@Chanuku

Chanuku commented Aug 30, 2024

Hi, I had a similar concern and I solved the problem by duplicating the hint channels in the pretrained model (two stacked images as the hint).
Here is the code (edit batch_size and the number of GPUs based on your setup):

from share import *
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from tutorial_dataset import MyDataset
from cldm.logger import ImageLogger
from cldm.model import create_model, load_state_dict
import torch

# Configs
resume_path = './models/control_sd15_ini.ckpt'
batch_size = 12
logger_freq = 300
learning_rate = 1e-5
sd_locked = True
only_mid_control = False

def modify_state_dict_for_6_channels(state_dict):
    for key in state_dict.keys():
        if 'input_hint_block' in key and 'weight' in key:
            old_weight = state_dict[key]
            if old_weight.shape[1] == 3:
                new_weight = torch.zeros(old_weight.shape[0], 6, old_weight.shape[2], old_weight.shape[3], 
                                         device=old_weight.device, dtype=old_weight.dtype)
                new_weight[:, :3, :, :] = old_weight
                new_weight[:, 3:, :, :] = old_weight  # Duplicate the weights for the additional channels
                state_dict[key] = new_weight
    return state_dict

if __name__ == '__main__':
    # First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
    model = create_model('./models/cldm_v15.yaml').cpu()
    state_dict = load_state_dict(resume_path, location='cpu')

    # Modify the state_dict to handle the new input channels
    state_dict = modify_state_dict_for_6_channels(state_dict)

    # Load the modified state dict
    model.load_state_dict(state_dict, strict=False)

    model.learning_rate = learning_rate
    model.sd_locked = sd_locked
    model.only_mid_control = only_mid_control

    # Misc
    dataset = MyDataset()
    dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
    logger = ImageLogger(batch_frequency=logger_freq)
    trainer = pl.Trainer(gpus=2, precision=32, callbacks=[logger])

    # Train!
    trainer.fit(model, dataloader)
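
One optional refinement (my own assumption, not something tested in this thread): halving the duplicated weights keeps the initial response of input_hint_block on roughly the pretrained scale when both hint halves carry similar images:

    new_weight[:, :3, :, :] = old_weight / 2.0  # assumption: rescale so the combined
    new_weight[:, 3:, :, :] = old_weight / 2.0  # response matches the original 3-channel one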
