self-trained ControlNet is slower than standard #683

Open
Liuqh12 opened this issue Jun 13, 2024 · 0 comments

Liuqh12 commented Jun 13, 2024

I trained my own ControlNet following tutorial_train.py.

After training, I got my_cn.ckpt, which is about 8 GB.

my_cn.ckpt loads and runs in gradio_scribble2image.py and produces the expected results. The only change needed is:

model.load_state_dict(load_state_dict('./models/my_cn.ckpt', location='cuda'))

However, during inference, I found that my_cn is several times slower than your checkpoint from Hugging Face.

I printed the state_dict of both my_cn.ckpt and control_sd15_scribble.pth; both are torch.float32.
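For reference, a check like this can be sketched as below. `summarize_state_dict` is a hypothetical helper, not part of the ControlNet repo. It only assumes that each tensor entry exposes `dtype`, `numel()` and `element_size()`, as `torch.Tensor` does. Besides the dtypes, it reports the total size, which may help when comparing two checkpoints of different file size.

```python
from collections import Counter


def summarize_state_dict(state_dict):
    """Return (dtype counts, total element count, total bytes) for tensor entries.

    Hypothetical helper for illustration; works on any mapping whose tensor
    values provide .dtype, .numel() and .element_size() (e.g. torch.Tensor).
    """
    dtype_counts = Counter()
    total_params = 0
    total_bytes = 0
    for name, value in state_dict.items():
        # Skip non-tensor entries; a checkpoint may also hold plain Python objects.
        if not hasattr(value, "dtype") or not hasattr(value, "numel"):
            continue
        dtype_counts[str(value.dtype)] += 1
        total_params += value.numel()
        total_bytes += value.numel() * value.element_size()
    return dict(dtype_counts), total_params, total_bytes
```

Assuming the repo's loader is used, it could be called as `summarize_state_dict(load_state_dict('./models/my_cn.ckpt', location='cpu'))` and again for control_sd15_scribble.pth, then the two results compared.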

I tested the ControlNet alone with the following code:

from share import *
import cv2
import torch
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler
from tqdm import tqdm
from torchinfo import summary

model = create_model('./models/cldm_v15.yaml').cpu()

# Standard checkpoint: 317 ms average, 1445.12M params
sketch_ckpt_path = './models/control_sd15_scribble.pth'

# Self-trained checkpoint: 1545 ms average, 1445.12M params
# sketch_ckpt_path = './models/my_cn.ckpt'

model.load_state_dict(load_state_dict(sketch_ckpt_path, location='cuda'))

model = model.cuda()
control_net = model.control_model

# random inputs with the shapes the ControlNet expects
x = torch.rand((1, 4, 64, 64)).to('cuda')
hint = torch.rand((1, 3, 512, 512)).to('cuda')
timesteps = torch.rand((1)).to('cuda')
context = torch.rand((1, 77, 768)).to('cuda')

# print model information: https://github.com/TylerYep/torchinfo
summary(control_net, input_data=[x, hint, timesteps, context])

epoch = 50
e_sum = 0.00
for i in tqdm(range(0, epoch)):
    begin = cv2.getTickCount()
    control_net(x, hint, timesteps, context)
    end = cv2.getTickCount()
    # to ms
    e_sum += (end - begin) / cv2.getTickFrequency() * 1000.0
print(e_sum / epoch)
print("Done!")
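One caveat about the loop above: CUDA kernels are launched asynchronously, so wall-clock timing around a single call may not reflect the actual GPU time unless the device is synchronized. A minimal sketch of a more careful measurement follows. `time_call_ms` is a hypothetical helper, not part of the repo, and the warm-up count is an arbitrary choice.

```python
import time


def time_call_ms(fn, sync=None, warmup=5, iters=50):
    """Average wall-clock time of fn() in milliseconds.

    sync is an optional zero-argument callable run before the timer starts
    and before it stops; for GPU work, torch.cuda.synchronize can be passed.
    """
    if sync is None:
        sync = lambda: None
    # Warm-up calls are excluded from the measurement.
    for _ in range(warmup):
        fn()
    sync()  # make sure warm-up work has finished before timing starts
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()  # wait for all timed work to finish before reading the clock
    return (time.perf_counter() - start) * 1000.0 / iters
```

With the variables from the script above, this could be used as `time_call_ms(lambda: control_net(x, hint, timesteps, context), sync=torch.cuda.synchronize)`. Running the call under `torch.no_grad()` may also be worth trying, since the loop above does not disable gradient tracking. Neither change is confirmed to explain the gap between the two checkpoints.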

I think I must have missed some detail. I look forward to your suggestions.

Liuqh12 changed the title from "self-trained is slower than standard" to "self-trained ControlNet is slower than standard" on Jun 13, 2024