238 changes: 238 additions & 0 deletions big_sleep/clip.py
@@ -0,0 +1,238 @@
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

import hashlib
import html
import os
import urllib.request
import warnings
from collections import OrderedDict
from functools import lru_cache
from pathlib import Path

from PIL import Image
from tqdm import tqdm

import ftfy
import regex as re

MODEL_PATH = "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"

@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/bpe_simple_vocab_16e6.txt")

@lru_cache()
def bytes_to_unicode():
    # Map every byte (0-255) to a printable unicode character so BPE can
    # operate on visible text; bytes that would render as whitespace or
    # control characters are shifted up into the U+0100+ range.
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
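
# A quick sanity sketch (illustrative, not part of the module API): the
# mapping is a bijection over all 256 byte values, e.g.
#   b2u = bytes_to_unicode()
#   assert len(b2u) == 256 and b2u[ord("a")] == "a"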

def get_pairs(word):
    # Return the set of adjacent symbol pairs in a word, where a word is
    # represented as a tuple of symbols (variable-length strings).
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()

def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text

class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = Path(bpe_path).read_text(encoding='utf-8').split('\n')
        # the first line of the BPE file is a version header; the slice keeps
        # the 48894 merge rules that CLIP's 49408-token vocabulary is built
        # from (512 byte tokens + 48894 merges + 2 special tokens)
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
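
    # For example (illustrative): after lowercasing, self.pat splits
    # "hello, world!!" into ["hello", ",", "world", "!!"] and a contraction
    # like "don't" into ["don", "'t"].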

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token + '</w>'

        while True:
            # repeatedly merge the lowest-ranked (most frequent) pair until
            # no mergeable pair remains
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
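
# A minimal round-trip sketch for SimpleTokenizer (assuming the bundled BPE
# vocabulary file is present):
#   tok = SimpleTokenizer()
#   ids = tok.encode("a photo of a cat")
#   assert tok.decode(ids).strip() == "a photo of a cat"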

def _download(url, root = os.path.expanduser("~/.cache/clip")):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    # the SHA256 of the checkpoint is embedded in the download URL
    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target

# CLIP's published image normalization statistics (per-channel RGB mean, std)
normalize_image = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

def load(device = ("cuda" if torch.cuda.is_available() else "cpu")):
    model_path = _download(MODEL_PATH)
    model = torch.jit.load(model_path, map_location=device).eval()
    n_px = model.input_resolution.item()

    transform = Compose([
        Resize(n_px, interpolation=Image.BICUBIC),
        CenterCrop(n_px),
        lambda image: image.convert("RGB"),
        ToTensor(),
        normalize_image,
    ])

    # patch the device names hard-coded into the traced graph
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        graphs = [module.graph] if hasattr(module, "graph") else []
        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            graphs = [module.graph] if hasattr(module, "graph") else []
            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]: # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5: # 5 is the enum value for torch.float16
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, transform
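
# Usage sketch (illustrative; the first call downloads the ViT-B/32 weights,
# and "photo.jpg" is a placeholder path; on a CUDA machine, move the tensor
# to the GPU before encoding):
#   model, transform = load()
#   image = transform(Image.open("photo.jpg")).unsqueeze(0)
#   with torch.no_grad():
#       image_features = model.encode_image(image)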

tokenizer = SimpleTokenizer()

def tokenize(texts, context_length: int = 77):
    if isinstance(texts, str):
        texts = [texts]

    sot_token = tokenizer.encoder["<|startoftext|>"]
    eot_token = tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
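
# A minimal end-to-end sketch, guarded so it only runs when this module is
# executed directly; the prompt is purely illustrative.
if __name__ == "__main__":
    model, transform = load()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokens = tokenize(["a painting of a sleeping cat"]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(tokens)
    print(text_features.shape)  # (1, 512) for ViT-B/32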
262,145 changes: 262,145 additions & 0 deletions big_sleep/data/bpe_simple_vocab_16e6.txt


1 change: 1 addition & 0 deletions big_sleep/version.py
@@ -0,0 +1 @@
__version__ = '0.0.1'
43 changes: 43 additions & 0 deletions setup.py
@@ -0,0 +1,43 @@
import sys
from setuptools import setup, find_packages

sys.path[0:0] = ['big_sleep']
from version import __version__

setup(
  name = 'big-sleep',
  packages = find_packages(),
  include_package_data = True,
  entry_points={
    'console_scripts': [
      'imagine_big = big_sleep.cli:main',
    ],
  },
  version = __version__,
  license = 'MIT',
  description = 'Big Sleep',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/big-sleep',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'text to image',
    'generative adversarial networks'
  ],
  install_requires=[
    'torch>=1.7.1',
    'torchvision',
    'einops>=0.3',
    'ftfy',
    'regex',
    'pytorch-pretrained-biggan',
    'tqdm'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)