-
Notifications
You must be signed in to change notification settings - Fork 8
Description
Thanks for your great work!
When attempting to upscale large images, I'm encountering GPU memory OOM errors, primarily in this function. In the attn2task function, what's the reasoning behind disabling the first two branches (AttnBlock and MemoryEfficientAttnBlock)? Would there be any issues if I enable the MemoryEfficientAttnBlock branch? Could that potentially reduce memory requirements?
def attn2task(task_queue: list, net) -> None:
    """Append the task sequence for one attention block to ``task_queue``.

    Four entries are appended in order: ``'store_res'``, ``'pre_norm'``,
    ``'attn'`` and ``'add_res'``. The first three are ``(name, callable)``
    tuples; the last is the list ``['add_res', None]``.

    Only the final ``else`` branch is live. The two earlier branches are
    switched off with ``if False`` / ``elif False``, and their original
    ``isinstance`` checks survive only as trailing comments.
    NOTE(review): the reason they were disabled is not visible here --
    confirm with the author before re-enabling either one.

    Args:
        task_queue: List the task entries are appended to (mutated in place).
        net: Attention module. The live branch reads ``net.group_norm`` and
            passes ``net`` to ``attn_forward_new``.
    """
    if False:  # isinstance(net, AttnBlock):
        # Disabled: plain attention path (uses net.norm + attn_forward).
        task_queue.append(('store_res', lambda x: x))
        task_queue.append(('pre_norm', net.norm))
        task_queue.append(('attn', lambda x, net=net: attn_forward(net, x)))
        task_queue.append(['add_res', None])
    elif False:  # isinstance(net, MemoryEfficientAttnBlock):
        # Disabled: xformers path (uses net.norm + xformer_attn_forward).
        task_queue.append(('store_res', lambda x: x))
        task_queue.append(('pre_norm', net.norm))
        task_queue.append(
            ('attn', lambda x, net=net: xformer_attn_forward(net, x)))
        task_queue.append(['add_res', None])
    else:
        # Identity callable; presumably the consumer uses this step to keep
        # the input around for the residual add -- TODO confirm.
        task_queue.append(('store_res', lambda x: x))
        task_queue.append(('pre_norm', net.group_norm))
        # ``net=net`` binds the module at definition time, not at call time.
        task_queue.append(('attn', lambda x, net=net: attn_forward_new(net, x)))
        # A list rather than a tuple, so the entry stays mutable; presumably
        # the consumer fills the ``None`` slot later -- TODO confirm.
        task_queue.append(['add_res', None])