Skip to content

Commit

Permalink
Add download script, fix training tab
Browse files Browse the repository at this point in the history
  • Loading branch information
lucataco committed Nov 8, 2023
1 parent f801c03 commit 54410ec
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 4 deletions.
7 changes: 6 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
__pycache
__pycache__
.cog
blip-cache
blip-proc-cache
checkpoint
model-cache
swin2sr-cache
temp
temp_in
training_out
output.tar
trained_model.tar
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ First, download the pre-trained weights:

Then, you can run predictions:

cog train -i input_images=@zeke.zip -i use_face_detection_instead=True
cog predict -i input_images=@zeke.zip -i use_face_detection_instead=True

## Example:

Expand Down
3 changes: 1 addition & 2 deletions cog.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,4 @@ build:
- wget http://thegiflibrary.tumblr.com/post/11565547760 -O face_landmarker_v2_with_blendshapes.task -q https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"
train: "train.py:train"
predict: "predict.py:Predictor"
25 changes: 25 additions & 0 deletions script/download-weights
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/usr/bin/env python

# Run this before you deploy it on replicate
import os
import sys
import torch
from diffusers import StableDiffusionXLPipeline

# append project directory to path so predict.py can be imported
sys.path.append('.')
from predict import MODEL_NAME, MODEL_CACHE

# Make cache folders
if not os.path.exists(MODEL_CACHE):
os.makedirs(MODEL_CACHE)

# SD-XL-Base-1.0 refiner
pipe = StableDiffusionXLPipeline.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16"
)

pipe.save_pretrained(MODEL_CACHE, safe_serialization=True)

0 comments on commit 54410ec

Please sign in to comment.