Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Intel Macs, standalone GPUs, add requirements.txt #14

Merged
merged 3 commits on Oct 19, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion maple-convert.py
@@ -1,4 +1,4 @@
#!/usr/bin/env python
#!/usr/bin/env python3
import sys
if len(sys.argv) < 2: raise ValueError(f"Usage: {sys.argv[0]} path_to_ckpt")

Expand Down
34 changes: 19 additions & 15 deletions maple-diffusion/MapleDiffusion.swift
@@ -1,15 +1,16 @@
import MetalPerformanceShadersGraph
import Foundation
import Accelerate
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This import is no longer needed, right? Other than that LGTM

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, yeah, sorry, missed it:)


// Maple Diffusion implements stable diffusion (original v1.4 model)
// inference via MPSGraph. iOS has a hard memory limit of 4GB (with
// a special entitlement), so this implementation trades off latency
// for memory usage in many places (tagged with MEM-HACK) in order to
// stay under the limit and minimize probability of oom.

func makeGraph() -> MPSGraph {
func makeGraph(synchonize: Bool) -> MPSGraph {
let graph = MPSGraph()
graph.options = MPSGraphOptions.none
graph.options = synchonize ? MPSGraphOptions.synchronizeResults : .none
return graph
}

Expand Down Expand Up @@ -625,6 +626,7 @@ class MapleDiffusion {
let graphDevice: MPSGraphDevice
let commandQueue: MTLCommandQueue
let saveMemory: Bool
let shouldSynchronize: Bool

// text tokenization
let tokenizer: BPETokenizer
Expand Down Expand Up @@ -666,17 +668,18 @@ class MapleDiffusion {
device = MTLCreateSystemDefaultDevice()!
graphDevice = MPSGraphDevice(mtlDevice: device)
commandQueue = device.makeCommandQueue()!
shouldSynchronize = !device.hasUnifiedMemory

// text tokenization
tokenizer = BPETokenizer()

// time embedding
tembGraph = makeGraph()
tembGraph = makeGraph(synchonize: shouldSynchronize)
tembTIn = tembGraph.placeholder(shape: [1], dataType: MPSDataType.int32, name: nil)
tembOut = makeTimeFeatures(graph: tembGraph, tIn: tembTIn)

// diffusion
diffGraph = makeGraph()
diffGraph = makeGraph(synchonize: shouldSynchronize)
diffXIn = diffGraph.placeholder(shape: [1, height, width, 4], dataType: MPSDataType.float16, name: nil)
diffEtaUncondIn = diffGraph.placeholder(shape: [1, height, width, 4], dataType: MPSDataType.float16, name: nil)
diffEtaCondIn = diffGraph.placeholder(shape: [1, height, width, 4], dataType: MPSDataType.float16, name: nil)
Expand All @@ -703,7 +706,7 @@ class MapleDiffusion {
}

private func initTextGuidance() {
let graph = makeGraph()
let graph = makeGraph(synchonize: shouldSynchronize)
let textGuidanceIn = graph.placeholder(shape: [2, 77], dataType: MPSDataType.int32, name: nil)
let textGuidanceOut = makeTextGuidance(graph: graph, xIn: textGuidanceIn, name: "cond_stage_model.transformer.text_model")
let textGuidanceOut0 = graph.sliceTensor(textGuidanceOut, dimension: 0, start: 0, length: 1, name: nil)
Expand All @@ -712,7 +715,7 @@ class MapleDiffusion {
}

private func initAnUnexpectedJourney() {
let graph = makeGraph()
let graph = makeGraph(synchonize: shouldSynchronize)
let xIn = graph.placeholder(shape: [1, height, width, 4], dataType: MPSDataType.float16, name: nil)
let condIn = graph.placeholder(shape: [saveMemory ? 1 : 2, 77, 768], dataType: MPSDataType.float16, name: nil)
let tembIn = graph.placeholder(shape: [1, 320], dataType: MPSDataType.float16, name: nil)
Expand All @@ -723,7 +726,7 @@ class MapleDiffusion {
}

private func initTheDesolationOfSmaug() {
let graph = makeGraph()
let graph = makeGraph(synchonize: shouldSynchronize)
let condIn = graph.placeholder(shape: [saveMemory ? 1 : 2, 77, 768], dataType: MPSDataType.float16, name: nil)
let placeholders = anUnexpectedJourneyShapes.map{graph.placeholder(shape: $0, dataType: MPSDataType.float16, name: nil)} + [condIn]
theDesolationOfSmaugIndices.removeAll()
Expand All @@ -737,7 +740,7 @@ class MapleDiffusion {
}

private func initTheBattleOfTheFiveArmies() {
let graph = makeGraph()
let graph = makeGraph(synchonize: shouldSynchronize)
let condIn = graph.placeholder(shape: [saveMemory ? 1 : 2, 77, 768], dataType: MPSDataType.float16, name: nil)
let unetPlaceholders = theDesolationOfSmaugShapes.map{graph.placeholder(shape: $0, dataType: MPSDataType.float16, name: nil)} + [condIn]
theBattleOfTheFiveArmiesIndices.removeAll()
Expand All @@ -750,9 +753,9 @@ class MapleDiffusion {
}

private func randomLatent(seed: Int) -> MPSGraphTensorData {
let graph = makeGraph()
let graph = makeGraph(synchonize: shouldSynchronize)
let out = graph.randomTensor(withShape: [1, height, width, 4], descriptor: MPSGraphRandomOpDescriptor(distribution: .normal, dataType: .float16)!, seed: seed, name: nil)
return graph.run(feeds: [:], targetTensors: [out], targetOperations: nil)[out]!
return graph.run(with: commandQueue, feeds: [:], targetTensors: [out], targetOperations: nil)[out]!
}

private func runTextGuidance(baseTokens: [Int], tokens: [Int]) -> (MPSGraphTensorData, MPSGraphTensorData) {
Expand All @@ -765,7 +768,7 @@ class MapleDiffusion {
private func loadDecoderAndGetFinalImage(xIn: MPSGraphTensorData) -> MPSGraphTensorData {
// MEM-HACK: decoder is loaded from disc and deallocated to save memory (at cost of latency)
let x = xIn
let decoderGraph = makeGraph()
let decoderGraph = makeGraph(synchonize: shouldSynchronize)
let decoderIn = decoderGraph.placeholder(shape: x.shape, dataType: MPSDataType.float16, name: nil)
let decoderOut = makeDecoder(graph: decoderGraph, xIn: decoderIn)
return decoderGraph.run(with: commandQueue, feeds: [decoderIn: x], targetTensors: [decoderOut], targetOperations: nil)[decoderOut]!
Expand Down Expand Up @@ -807,19 +810,19 @@ class MapleDiffusion {

private func runBatchedUNet(latent: MPSGraphTensorData, baseGuidance: MPSGraphTensorData, textGuidance: MPSGraphTensorData, temb: MPSGraphTensorData) -> (MPSGraphTensorData, MPSGraphTensorData) {
// concat
var graph = makeGraph()
var graph = makeGraph(synchonize: shouldSynchronize)
let bg = graph.placeholder(shape: baseGuidance.shape, dataType: MPSDataType.float16, name: nil)
let tg = graph.placeholder(shape: textGuidance.shape, dataType: MPSDataType.float16, name: nil)
let concatGuidance = graph.concatTensors([bg, tg], dimension: 0, name: nil)
let concatGuidanceData = graph.run(feeds: [bg : baseGuidance, tg: textGuidance], targetTensors: [concatGuidance], targetOperations: nil)[concatGuidance]!
let concatGuidanceData = graph.run(with: commandQueue, feeds: [bg : baseGuidance, tg: textGuidance], targetTensors: [concatGuidance], targetOperations: nil)[concatGuidance]!
// run
let concatEtaData = runUNet(latent: latent, guidance: concatGuidanceData, temb: temb)
// split
graph = makeGraph()
graph = makeGraph(synchonize: shouldSynchronize)
let etas = graph.placeholder(shape: concatEtaData.shape, dataType: concatEtaData.dataType, name: nil)
let eta0 = graph.sliceTensor(etas, dimension: 0, start: 0, length: 1, name: nil)
let eta1 = graph.sliceTensor(etas, dimension: 0, start: 1, length: 1, name: nil)
let etaRes = graph.run(feeds: [etas: concatEtaData], targetTensors: [eta0, eta1], targetOperations: nil)
let etaRes = graph.run(with: commandQueue, feeds: [etas: concatEtaData], targetTensors: [eta0, eta1], targetOperations: nil)
return (etaRes[eta0]!, etaRes[eta1]!)
}

Expand Down Expand Up @@ -908,3 +911,4 @@ func tensorToCGImage(data: MPSGraphTensorData) -> CGImage {
data.mpsndarray().readBytes(&imageArrayCPUBytes, strideBytes: nil)
return CGImage(width: shape[2], height: shape[1], bitsPerComponent: 8, bitsPerPixel: 32, bytesPerRow: shape[2]*shape[3], space: CGColorSpaceCreateDeviceRGB(), bitmapInfo: CGBitmapInfo(rawValue: CGBitmapInfo.byteOrder32Big.rawValue | CGImageAlphaInfo.noneSkipLast.rawValue), provider: CGDataProvider(data: NSData(bytes: &imageArrayCPUBytes, length: imageArrayCPUBytes.count))!, decode: nil, shouldInterpolate: true, intent: CGColorRenderingIntent.defaultIntent)!
}

4 changes: 4 additions & 0 deletions requirements.txt
@@ -0,0 +1,4 @@
numpy>=1.23
pytorch-lightning>=1.7
requests>=2.28
torch>=1.12