In [None]:
import asyncio
import pickle
from typing import Literal
import math

import numpy as np
from IPython.display import display
from PIL import Image
import meadowrun
import matplotlib.pyplot as plt

# Overview
To run this notebook, follow the [accompanying blog post](https://medium.com/p/e8aef6f974c1), or follow the quick start below:

```shell
# Clone this repo and create the local environment
git clone https://github.com/meadowdata/meadowrun-dallemini-demo
cd meadowrun-dallemini-demo
python3 -m venv venv
source venv/bin/activate
pip install -r local_requirements.txt

# Install meadowrun in your AWS account
meadowrun-manage-ec2 install --allow-authorize-ips
# Create an S3 bucket to cache pretrained models
aws s3 mb s3://meadowrun-dallemini
# Grant permission to Meadowrun to access this bucket
meadowrun-manage-ec2 grant-permission-to-s3-bucket meadowrun-dallemini

# Run a jupyter server
jupyter notebook
```

You'll also need to make sure your AWS account has non-zero quotas for at least some GPU instance types:
- L-3819A6DF: [All G and VT Spot Instance Requests](https://console.aws.amazon.com/servicequotas/home/services/ec2/quotas/L-3819A6DF)
- L-7212CCBC: [All P Spot Instance Requests](https://console.aws.amazon.com/servicequotas/home/services/ec2/quotas/L-7212CCBC)
- L-DB2E81BA: [Running On-Demand G and VT instances](https://console.aws.amazon.com/servicequotas/home/services/ec2/quotas/L-DB2E81BA)
- L-417A185B: [Running On-Demand P instances](https://console.aws.amazon.com/servicequotas/home/services/ec2/quotas/L-417A185B)

## Parameters for caching

In [None]:
# You must set this to match your S3 bucket that you create (see Overview).
# If "meadowrun-dallemini" is already taken (S3 bucket names are shared across all AWS
# accounts), use your own name here and in the setup commands above.
S3_BUCKET_NAME = "meadowrun-dallemini"
# Both values are passed to every caching and model job below.
# NOTE(review): presumably this must be the region the bucket was created in -- confirm.
S3_BUCKET_REGION = "us-east-2"

In [None]:
# A function for showing a grid of images
def show_images(images, width=20, columns=3):
    """Plot ``images`` in a grid, titling each subplot with its index.

    The index titles let a later cell pick a specific image out of the
    sequence (e.g. ``chosen_image = images[6]``).

    Args:
        images: iterable of images accepted by ``Axes.imshow`` (e.g. numpy
            arrays or PIL images). An empty iterable draws nothing.
        width: figure width in inches. Grid cells are square, so the figure
            height is ``width / columns * rows``.
        columns: number of images per row; must be at least 1.

    Raises:
        ValueError: if ``columns`` is less than 1.
    """
    if columns < 1:
        raise ValueError("columns must be at least 1")
    # Materialize so generators work and we can count the images up front.
    images = list(images)
    if not images:
        # Nothing to draw; also avoids creating a zero-height figure.
        return
    rows = math.ceil(len(images) / columns)
    height = (width / columns) * rows
    fig = plt.figure(figsize=(width, height))
    for index, image in enumerate(images):
        ax = fig.add_subplot(rows, columns, index + 1)
        ax.set_title(str(index))
        ax.axis("off")
        # Draw on this subplot explicitly instead of pyplot's implicit "current axes".
        ax.imshow(image)

    fig.tight_layout()

In [None]:
# We have two deployments, one for caching and one for running models
async def caching_deployment():
    """Build the Meadowrun deployment used by the S3-caching jobs.

    Calls ``meadowrun.Deployment.mirror_local`` with a pip interpreter built
    from ``caching_requirements.txt`` and "3.9" (presumably the Python version).
    """
    requirements = meadowrun.PipRequirementsFile("caching_requirements.txt", "3.9")
    deployment = await meadowrun.Deployment.mirror_local(interpreter=requirements)
    return deployment


async def model_deployment():
    """Build the Meadowrun deployment used by the model (GPU) jobs.

    Calls ``meadowrun.Deployment.mirror_local`` with a pip interpreter built
    from ``model_requirements.txt``, "3.8" (presumably the Python version) and
    two extra packages, ``libgl1`` and ``libglib2.0-0`` (presumably system
    packages the image libraries need -- confirm).
    """
    extra_packages = ["libgl1", "libglib2.0-0"]
    requirements = meadowrun.PipRequirementsFile("model_requirements.txt", "3.8", extra_packages)
    deployment = await meadowrun.Deployment.mirror_local(interpreter=requirements)
    return deployment

In [None]:
# Cache the DALL·E Mini pre-trained model
# Runs download_pretrained_dallemini_cache_in_s3 (module linux.cache_in_s3, not shown here)
# on an EC2 instance. This job requests no GPU. Positional args: model version, bucket, region.
# NOTE(review): the version is hard-coded to "mega_full" even though `model_version` below
# can be "mini" or "mega". If you change that, you may need to cache that version instead.
await meadowrun.run_function(
    "linux.cache_in_s3.download_pretrained_dallemini_cache_in_s3",
    meadowrun.AllocCloudInstance("EC2"),
    meadowrun.Resources(1, 2, 80),  # presumably (logical_cpu, memory_gb, max_eviction_rate) -- confirm
    await caching_deployment(),
    ["mega_full", S3_BUCKET_NAME, S3_BUCKET_REGION]
)

In [None]:
# Cache the glid-3-xl pre-trained model
# Same CPU-only caching setup as the DALL·E Mini cell; the only args are the bucket name and region.
# NOTE(review): the remote function name is spelled "gild3xl", not "glid3xl". It must match
# the name defined in linux.cache_in_s3, so check that module before "fixing" the spelling.
await meadowrun.run_function(
    "linux.cache_in_s3.download_pretrained_gild3xl_cache_in_s3",
    meadowrun.AllocCloudInstance("EC2"),
    meadowrun.Resources(1, 2, 80),  # presumably (logical_cpu, memory_gb, max_eviction_rate) -- confirm
    await caching_deployment(),
    [S3_BUCKET_NAME, S3_BUCKET_REGION]
)

In [None]:
# Cache the SwinIR pre-trained model
# Same CPU-only caching setup as the other caching cells; the only args are the bucket name and region.
await meadowrun.run_function(
    "linux.cache_in_s3.download_pretrained_swinir_cache_in_s3",
    meadowrun.AllocCloudInstance("EC2"),
    meadowrun.Resources(1, 2, 80),  # presumably (logical_cpu, memory_gb, max_eviction_rate) -- confirm
    await caching_deployment(),
    [S3_BUCKET_NAME, S3_BUCKET_REGION]
)

## Parameters for DALL·E Mini

In [None]:
# Text prompt for DALL·E Mini; the same prompt is also passed to glid-3-xl below.
prompt = "batman praying in the garden of gethsemane"
# Number of DALL·E Mini images to generate. The glid-3-xl section picks index 6 from this
# batch, so keep this at 7 or more (or change that index). The name is reassigned there.
num_images = 8
# Options are mini, mega, mega_full (the memory lookups in the next cell only accept these)
model_version: Literal["mini", "mega", "mega_full"] = "mega_full"

In [None]:
# Per-version (GPU memory, main memory) requirements, presumably in GB.
# An unknown model_version raises KeyError here.
gpu_memory_required, main_memory_required = {
    "mini": (4, 16),
    "mega": (8, 20),
    "mega_full": (12, 24),
}[model_version]
# Shared by the DALL·E Mini, glid-3-xl and SwinIR runs: an NVIDIA GPU with enough memory for
# the chosen model version. The 1 and 80 mirror the caching cells' Resources(1, ..., 80).
model_ec2_instance_requirements = meadowrun.Resources(
    1,
    main_memory_required,
    80,
    gpu_memory=gpu_memory_required,
    flags="nvidia",
)

In [None]:
# History of DALL·E Mini batches. The generation cell below appends each result here, so
# re-running that cell keeps earlier batches. Re-running *this* cell clears the history.
saved: list = []

In [None]:
# Generate `num_images` images for `prompt` with DALL·E Mini on an NVIDIA GPU instance.
# Args are passed positionally to linux.dalle_wrapper.generate_images_api (not shown here):
# model version, prompt, image count, then the S3 bucket name and region.
# NOTE(review): elements are later passed to Image.fromarray, so this appears to be a
# sequence of numpy arrays -- confirm against generate_images_api.
dallemini_images = await meadowrun.run_function(
    "linux.dalle_wrapper.generate_images_api",
    meadowrun.AllocCloudInstance("EC2"),
    model_ec2_instance_requirements,
    await model_deployment(),
    [model_version, prompt, num_images, S3_BUCKET_NAME, S3_BUCKET_REGION]
)
# Keep every batch so re-running this cell does not lose earlier results.
saved.append(dallemini_images)

In [None]:
# Subplot titles are list indices; use one to set `chosen_image` in the glid-3-xl section.
show_images(dallemini_images)

## Parameters for glid-3-xl

In [None]:
# Index picked by eye from the DALL·E Mini grid above; change it to the image you want to
# pass to glid-3-xl. It must be a valid index into dallemini_images.
chosen_image = dallemini_images[6]
# Number of glid-3-xl images to generate (reuses the num_images name from the DALL·E section).
num_images = 8

In [None]:
# Run glid-3-xl on the chosen DALL·E Mini image (converted to a PIL image), reusing the same
# GPU requirements and model deployment as the DALL·E Mini run. Args are passed positionally
# to linux.glid3xl_wrapper.do_run (not shown): image, bucket, region, prompt, image count.
glid3xl_images = await meadowrun.run_function(
    "linux.glid3xl_wrapper.do_run",
    meadowrun.AllocCloudInstance("EC2"),
    model_ec2_instance_requirements,
    await model_deployment(),
    [Image.fromarray(chosen_image), S3_BUCKET_NAME, S3_BUCKET_REGION, prompt, num_images]
)
# Keep only the first element of each result entry.
# NOTE(review): presumably each entry wraps a single image -- confirm against do_run.
glid3xl_images = [i[0] for i in glid3xl_images]

In [None]:
# Subplot titles are list indices; use one to set `chosen_image` in the SwinIR section.
show_images(glid3xl_images)

## Parameters for SwinIR

In [None]:
# Index picked by eye from the glid-3-xl grid above; must be a valid index into glid3xl_images.
chosen_image = glid3xl_images[0]

In [None]:
# Run SwinIR on the chosen glid-3-xl image (converted to a numpy array) with the same GPU
# requirements and model deployment. The result is turned back into a PIL image in the next cell.
# NOTE(review): SwinIR is presumably used here for upscaling -- confirm against
# linux.swinir_wrapper.main.
image = await meadowrun.run_function(
    "linux.swinir_wrapper.main",
    meadowrun.AllocCloudInstance("EC2"),
    model_ec2_instance_requirements,
    await model_deployment(),
    [np.asarray(chosen_image), S3_BUCKET_NAME, S3_BUCKET_REGION]
)

In [None]:
# Convert SwinIR's output array back to a PIL image. As the cell's last expression,
# it is rendered inline.
final_image = Image.fromarray(image)
final_image