In [None]:
! pip install git+https://github.com/openai/CLIP

In [None]:
from pathlib import Path

import clip
import torch
from torchvision import transforms
from tqdm.notebook import trange

In [2]:
def get_embedding(device, shape, lr, prompt, clip_model):
    """Create the image, optimiser, augmentations, CLIP model and text embedding for one run.

    Args:
        device: Torch device (string or ``torch.device``) used for the image,
            the CLIP model and the tokenised prompt.
        shape: Two-element sequence with the spatial size of the image; the
            tensor is created as ``(1, 3, shape[0], shape[1])``.
        lr: Learning rate handed to the Adam optimiser.
        prompt: Text that is tokenised with ``clip.tokenize`` and encoded by CLIP.
        clip_model: Model name forwarded to ``clip.load`` (e.g. ``"ViT-B/32"``).

    Returns:
        A tuple ``(image, opt, f, m, embedding)``:
        ``image`` is the trainable tensor initialised with ``torch.rand``,
        ``opt`` is an Adam optimiser whose only parameter is ``image``,
        ``f`` is the random augmentation pipeline applied before CLIP,
        ``m`` is the CLIP model in eval mode with gradients disabled on its weights,
        ``embedding`` is the CLIP text encoding of ``prompt``.
    """
    # The image itself is the only thing being optimised.
    image = torch.rand((1, 3, shape[0], shape[1]), device=device, requires_grad=True)

    opt = torch.optim.Adam((image,), lr)

    # NOTE(review): the first stage rescales its input with (x + 1) / 2 before
    # clamping to [0, 1]. The loops in this notebook already call
    # f(image.add(1).div(2)), so the rescale is applied twice and CLIP sees a
    # different value range than the image that is finally rendered -- confirm
    # whether this is intended before changing either side.
    f = transforms.Compose([
        lambda x: torch.clamp((x + 1) / 2, min=0, max=1),
        transforms.RandomAffine(degrees=60, translate=(0.1, 0.1)),
        transforms.RandomGrayscale(p=0.2),
        # A little Gaussian noise so every augmented copy differs slightly.
        transforms.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
        transforms.Resize(224),
    ])

    # Hand `device` to clip.load so the model is built for the requested device.
    # Without it clip.load falls back to its own default device (CUDA when
    # available) and the weights are only moved afterwards.
    m = clip.load(clip_model, device=device, jit=False)[0].eval().requires_grad_(False).to(device)

    embedding = m.encode_text(clip.tokenize(prompt).to(device))

    return image, opt, f, m, embedding

def total_variation_loss(img):
    """Return the squared total variation of ``img`` divided by its element count.

    Args:
        img: 4-D tensor; differences are taken between neighbours along
            dimension 2 and along dimension 3.

    Returns:
        A scalar tensor: the sum of squared neighbour differences along both
        dimensions, divided by ``img.numel()``.
    """
    yv = torch.pow(img[:, :, 1:, :] - img[:, :, :-1, :], 2).sum()
    xv = torch.pow(img[:, :, :, 1:] - img[:, :, :, :-1], 2).sum()
    # Normalise by the size of the tensor that was actually passed in instead of
    # a notebook-level `shape` variable; for a (1, 3, H, W) image this is the
    # same value, 1 * 3 * H * W.
    return (yv + xv) / img.numel()

def spherical_distance_loss(x, y):
    """Mean squared-angle loss between ``x`` and ``y`` along their last dimension.

    Both inputs are L2-normalised over the last dimension. With ``d`` the norm
    of the difference of the unit vectors, each entry contributes
    ``2 * arcsin(d / 2) ** 2``; the entries are then averaged.
    """
    unit_x = torch.nn.functional.normalize(x, dim=-1)
    unit_y = torch.nn.functional.normalize(y, dim=-1)
    half_chord = (unit_x - unit_y).norm(dim=-1) / 2
    per_pair = half_chord.arcsin().pow(2) * 2
    return per_pair.mean()

In [None]:
# Test the model on a single prompt and show the result inline.

device = 'cpu'
cutn = 16  # number of randomly augmented copies of the image scored per step
shape = (256, 256)
clip_model = "ViT-B/32"
lr = 0.03
prompt = "Paste here your prompt to generate an image"
steps = 500
seed = 0

# The initial image and the augmentations are random; seed torch so that a
# fresh "Restart & Run All" starts from the same random state.
torch.manual_seed(seed)

image, opt, f, m, embedding = get_embedding(device, shape, lr, prompt, clip_model)
for i in trange(steps):
    opt.zero_grad()
    clip_in = m.encode_image(torch.cat([f(image.add(1).div(2)) for _ in range(cutn)]))
    # Loss = distance to the text embedding
    #      + penalty for pixel values leaving [-1, 1]
    #      + smoothness (total variation) term.
    clip_term = spherical_distance_loss(clip_in, embedding.unsqueeze(0))
    range_term = (image - image.clamp(-1, 1)).pow(2).mean() / 2
    loss = clip_term + range_term + total_variation_loss(image)
    loss.backward()
    opt.step()

# Map [-1, 1] to [0, 1] for PIL; detach because the image still tracks gradients.
img = transforms.ToPILImage()(image.detach().squeeze(0).clamp(-1, 1) / 2 + .5)
# A bare last expression renders the image inline in the notebook; `img.show()`
# would instead try to launch an external viewer on the kernel's machine.
img


In [None]:
# Generate multiple images with different parameters and prompts, then save them.
# Learning rate, number of steps and part of the prompt are saved in the name.

# Relative output directory. It has to be the same directory the images are
# written to below; an absolute "/gen_imgs" would be created at the filesystem
# root while the files were saved under ./gen_imgs.
out_dir = Path("gen_imgs")
out_dir.mkdir(parents=True, exist_ok=True)

device = 'cpu'
cutn = 16  # number of randomly augmented copies of the image scored per step
shape = (256, 256)
clip_model = "ViT-B/32"

# Characters that are not allowed in file names on common platforms.
unsafe_chars = '\\/:*?"<>|'

for document in ['patient utterances.txt', 'interpretation of G. Bateson.txt']:
    # Read all prompts first so the file is closed before the long-running
    # generation starts. Text mode already yields a single "\n" per line, so
    # strip whitespace instead of cutting a fixed number of characters, and
    # drop blank lines.
    with open(document, 'r') as file:
        prompts = [line.strip() for line in file if line.strip()]

    for prompt in prompts:
        # Prompt fragment used in the file name, with path separators and other
        # reserved/control characters replaced so saving cannot fail on them.
        prompt_tag = "".join(
            "_" if (ch in unsafe_chars or ord(ch) < 32) else ch
            for ch in prompt[:35]
        )

        for lr in [0.02, 0.03, 0.04]:
            for steps in [400, 500, 600]:

                image, opt, f, m, embedding = get_embedding(device, shape, lr, prompt, clip_model)
                for i in trange(steps):
                    opt.zero_grad()
                    clip_in = m.encode_image(torch.cat([f(image.add(1).div(2)) for _ in range(cutn)]))
                    # Loss = distance to the text embedding
                    #      + penalty for pixel values leaving [-1, 1]
                    #      + smoothness (total variation) term.
                    clip_term = spherical_distance_loss(clip_in, embedding.unsqueeze(0))
                    range_term = (image - image.clamp(-1, 1)).pow(2).mean() / 2
                    loss = clip_term + range_term + total_variation_loss(image)
                    loss.backward()
                    opt.step()

                # Map [-1, 1] to [0, 1] for PIL; detach because the image still tracks gradients.
                img = transforms.ToPILImage()(image.detach().squeeze(0).clamp(-1, 1) / 2 + .5)
                # "=" instead of ":" keeps the name valid on every platform.
                img.save(out_dir / f'{document[:10]} {prompt_tag} lr={lr} steps={steps}.png')
    